#!/usr/bin/env python
# pylint: skip-file
#
# Tests for the vendored Tornado HTTP server stack (HTTPServer /
# HTTP1Connection). The areas covered are:
#   - SSL setup and bad SSL options
#   - raw request parsing (multipart, bare-LF newlines, 100-continue,
#     chunked and malformed input)
#   - argument/header type checks
#   - X-Real-IP / X-Forwarded-* handling
#   - unix sockets and keep-alive
#   - gzip request bodies and streaming chunk sizes
#   - header/body size limits and idle timeouts


from __future__ import absolute_import, division, print_function
from salt.ext.tornado import netutil
from salt.ext.tornado.escape import json_decode, json_encode, utf8, _unicode, recursive_unicode, native_str
from salt.ext.tornado import gen
from salt.ext.tornado.http1connection import HTTP1Connection
from salt.ext.tornado.httpserver import HTTPServer
from salt.ext.tornado.httputil import HTTPHeaders, HTTPMessageDelegate, HTTPServerConnectionDelegate, ResponseStartLine
from salt.ext.tornado.iostream import IOStream
from salt.ext.tornado.log import gen_log
from salt.ext.tornado.netutil import ssl_options_to_context
from salt.ext.tornado.simple_httpclient import SimpleAsyncHTTPClient
from salt.ext.tornado.testing import AsyncHTTPTestCase, AsyncHTTPSTestCase, AsyncTestCase, ExpectLog, gen_test
from salt.ext.tornado.test.util import unittest, skipOnTravis
from salt.ext.tornado.web import Application, RequestHandler, asynchronous, stream_request_body
from contextlib import closing
import datetime
import gzip
import os
import shutil
import socket
import ssl
import sys
import tempfile
from io import BytesIO


def read_stream_body(stream, callback):
    """Reads an HTTP response from `stream` and runs callback with its
    headers and body."""
    # Body chunks are accumulated here by the delegate below and joined once
    # the message is finished.
    chunks = []

    class Delegate(HTTPMessageDelegate):
        def headers_received(self, start_line, headers):
            self.headers = headers

        def data_received(self, chunk):
            chunks.append(chunk)

        def finish(self):
            # The callback receives a single (headers, body_bytes) tuple.
            callback((self.headers, b''.join(chunks)))
    # NOTE(review): the positional `True` presumably marks this as the client
    # side of the connection (we read a response); confirm against
    # HTTP1Connection's signature.
    conn = HTTP1Connection(stream, True)
    conn.read_response(Delegate())


class HandlerBaseTestCase(AsyncHTTPTestCase):
    # Base class for tests whose subclass defines a nested `Handler`
    # RequestHandler; that handler is served at '/'.
    def get_app(self):
        return Application([('/', self.__class__.Handler)])

    def fetch_json(self, *args, **kwargs):
        # Fetch, raise if the response is an error (rethrow), and decode the
        # body as JSON.
        response = self.fetch(*args, **kwargs)
        response.rethrow()
        return json_decode(response.body)


class HelloWorldRequestHandler(RequestHandler):
    # GET answers "Hello world" after checking that the request protocol
    # matches the one passed to initialize(). POST reports the body length.
    def initialize(self, protocol="http"):
        self.expected_protocol = protocol

    def get(self):
        if self.request.protocol != self.expected_protocol:
            raise Exception("unexpected protocol")
        self.finish("Hello world")

    def post(self):
        self.finish("Got %d bytes in POST" % len(self.request.body))


# In pre-1.0 versions of openssl, SSLv23 clients always send SSLv2
# ClientHello messages, which are rejected by SSLv3 and TLSv1
# servers. Note that while the OPENSSL_VERSION_INFO was formally
# introduced in python3.2, it was present but undocumented in
# python 2.7
skipIfOldSSL = unittest.skipIf(
    getattr(ssl, 'OPENSSL_VERSION_INFO', (0, 0)) < (1, 0),
    "old version of ssl module and/or openssl")


class BaseSSLTest(AsyncHTTPSTestCase):
    # Serves HelloWorldRequestHandler and requires requests to arrive over
    # https.
    def get_app(self):
        return Application([('/', HelloWorldRequestHandler,
                             dict(protocol="https"))])


class SSLTestMixin(object):
    # Shared SSL tests. Concrete subclasses supply get_ssl_version().
    def get_ssl_options(self):
        # NOTE(review): this calls AsyncHTTPSTestCase.get_ssl_options with no
        # `self`, while SSLContextTest below passes `self`. Only one form can
        # match how the base method is declared; confirm which.
        return dict(ssl_version=self.get_ssl_version(),  # type: ignore
                    **AsyncHTTPSTestCase.get_ssl_options())

    def get_ssl_version(self):
        raise NotImplementedError()

    def test_ssl(self):
        response = self.fetch('/')
        self.assertEqual(response.body, b"Hello world")

    def test_large_post(self):
        response = self.fetch('/',
                              method='POST',
                              body='A' * 5000)
        self.assertEqual(response.body, b"Got 5000 bytes in POST")

    def test_non_ssl_request(self):
        # Make sure the server closes the connection when it gets a non-ssl
        # connection, rather than waiting for a timeout or otherwise
        # misbehaving.
        with ExpectLog(gen_log, '(SSL Error|uncaught exception)'):
            with ExpectLog(gen_log, 'Uncaught exception', required=False):
                # The very long timeouts make a hang show up as a test
                # timeout rather than as a client-side 599.
                self.http_client.fetch(
                    self.get_url("/").replace('https:', 'http:'),
                    self.stop,
                    request_timeout=3600,
                    connect_timeout=3600)
                response = self.wait()
        self.assertEqual(response.code, 599)

    def test_error_logging(self):
        # No stack traces are logged for SSL errors.
        with ExpectLog(gen_log, 'SSL Error') as expect_log:
            self.http_client.fetch(
                self.get_url("/").replace("https:", "http:"),
                self.stop)
            response = self.wait()
            self.assertEqual(response.code, 599)
        self.assertFalse(expect_log.logged_stack)

# Python's SSL implementation differs significantly between versions.
# For example, SSLv3 and TLSv1 throw an exception if you try to read
# from the socket before the handshake is complete, but the default
# of SSLv23 allows it.


class SSLv23Test(BaseSSLTest, SSLTestMixin):
    def get_ssl_version(self):
        return ssl.PROTOCOL_SSLv23


@skipIfOldSSL
class SSLv3Test(BaseSSLTest, SSLTestMixin):
    # NOTE(review): ssl.PROTOCOL_SSLv3 is missing from Python/OpenSSL builds
    # compiled without SSLv3. There, get_ssl_version raises AttributeError
    # instead of the test being skipped.
    def get_ssl_version(self):
        return ssl.PROTOCOL_SSLv3


@skipIfOldSSL
class TLSv1Test(BaseSSLTest, SSLTestMixin):
    def get_ssl_version(self):
        return ssl.PROTOCOL_TLSv1


@unittest.skipIf(not hasattr(ssl, 'SSLContext'), 'ssl.SSLContext not present')
class SSLContextTest(BaseSSLTest, SSLTestMixin):
    # Passes a prebuilt ssl.SSLContext (rather than a dict) as ssl_options.
    def get_ssl_options(self):
        context = ssl_options_to_context(
            AsyncHTTPSTestCase.get_ssl_options(self))
        assert isinstance(context, ssl.SSLContext)
        return context


class BadSSLOptionsTest(unittest.TestCase):
    def test_missing_arguments(self):
        # A keyfile without a certfile is expected to raise KeyError.
        application = Application()
        self.assertRaises(KeyError, HTTPServer, application, ssl_options={
            "keyfile": "/__missing__.crt",
        })

    def test_missing_key(self):
        """A missing SSL key should cause an immediate exception."""

        application = Application()
        module_dir = os.path.dirname(__file__)
        existing_certificate = os.path.join(module_dir, 'test.crt')
        existing_key = os.path.join(module_dir, 'test.key')

        # The "mising" spelling is harmless: the path only needs to not exist.
        self.assertRaises((ValueError, IOError),
                          HTTPServer, application, ssl_options={
                              "certfile": "/__mising__.crt",
                          })
        self.assertRaises((ValueError, IOError),
                          HTTPServer, application, ssl_options={
                              "certfile": existing_certificate,
                              "keyfile": "/__missing__.key"
                          })

        # This actually works because both files exist
        HTTPServer(application, ssl_options={
                   "certfile": existing_certificate,
                   "keyfile": existing_key,
                   })


class MultipartTestHandler(RequestHandler):
    # Echoes one header, one form argument, and the name and decoded body of
    # the first uploaded "files" entry as JSON, so callers can check decoding.
    def post(self):
        self.finish({"header": self.request.headers["X-Header-Encoding-Test"],
                     "argument": self.get_argument("argument"),
                     "filename": self.request.files["files"][0].filename,
                     "filebody": _unicode(self.request.files["files"][0]["body"]),
                     })


# This test is also called from wsgi_test
class HTTPConnectionTest(AsyncHTTPTestCase):
    def get_handlers(self):
        return [("/multipart", MultipartTestHandler),
                ("/hello", HelloWorldRequestHandler)]

    def get_app(self):
        return Application(self.get_handlers())

    def raw_fetch(self, headers, body, newline=b"\r\n"):
        # Send a hand-built request over a raw socket and return only the
        # response body. Header lines are joined with `newline`, and a
        # Content-Length header for `body` is appended.
        with closing(IOStream(socket.socket())) as stream:
            stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
            self.wait()
            stream.write(
                newline.join(headers +
                             [utf8("Content-Length: %d" % len(body))]) +
                newline + newline + body)
            read_stream_body(stream, self.stop)
            headers, body = self.wait()
            return body

    def test_multipart_form(self):
        # Encodings here are tricky: Headers are latin1, bodies can be
        # anything (we use utf8 by default).
        response = self.raw_fetch([
            b"POST /multipart HTTP/1.0",
            b"Content-Type: multipart/form-data; boundary=1234567890",
            b"X-Header-encoding-test: \xe9",
        ],
            b"\r\n".join([
                b"Content-Disposition: form-data; name=argument",
                b"",
                u"\u00e1".encode("utf-8"),
                b"--1234567890",
                u'Content-Disposition: form-data; name="files"; filename="\u00f3"'.encode("utf8"),
                b"",
                u"\u00fa".encode("utf-8"),
                b"--1234567890--",
                b"",
            ]))
        data = json_decode(response)
        self.assertEqual(u"\u00e9", data["header"])
        self.assertEqual(u"\u00e1", data["argument"])
        self.assertEqual(u"\u00f3", data["filename"])
        self.assertEqual(u"\u00fa", data["filebody"])

    def test_newlines(self):
        # We support both CRLF and bare LF as line separators.
        for newline in (b"\r\n", b"\n"):
            response = self.raw_fetch([b"GET /hello HTTP/1.0"], b"",
                                      newline=newline)
            self.assertEqual(response, b'Hello world')

    def test_100_continue(self):
        # Run through a 100-continue interaction by hand:
        # When given Expect: 100-continue, we get a 100 response after the
        # headers, and then the real response after the body.
        # NOTE(review): unlike raw_fetch, this stream is not wrapped in
        # closing(). It is only closed on the last line, so it leaks if an
        # assertion fails first.
        stream = IOStream(socket.socket(), io_loop=self.io_loop)
        stream.connect(("127.0.0.1", self.get_http_port()), callback=self.stop)
        self.wait()
        stream.write(b"\r\n".join([b"POST /hello HTTP/1.1",
                                   b"Content-Length: 1024",
                                   b"Expect: 100-continue",
                                   b"Connection: close",
                                   b"\r\n"]), callback=self.stop)
        self.wait()
        stream.read_until(b"\r\n\r\n", self.stop)
        data = self.wait()
        self.assertTrue(data.startswith(b"HTTP/1.1 100 "), data)
        # Only now send the 1024-byte body promised in Content-Length.
        stream.write(b"a" * 1024)
        stream.read_until(b"\r\n", self.stop)
        first_line = self.wait()
        self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line)
        stream.read_until(b"\r\n\r\n", self.stop)
        header_data = self.wait()
        headers = HTTPHeaders.parse(native_str(header_data.decode('latin1')))
        stream.read_bytes(int(headers["Content-Length"]), self.stop)
        body = self.wait()
        self.assertEqual(body, b"Got 1024 bytes in POST")
        stream.close()


class EchoHandler(RequestHandler):
    # Returns the parsed request arguments (query string and/or form body)
    # as JSON, with values converted to unicode.
    def get(self):
        self.write(recursive_unicode(self.request.arguments))

    def post(self):
        self.write(recursive_unicode(self.request.arguments))


class TypeCheckHandler(RequestHandler):
    # Records in self.errors each request attribute whose exact type differs
    # from the expected one. The response is that dict, so an empty dict
    # means every type matched. Subclasses count as mismatches because
    # check_type compares type() with !=.
    def prepare(self):
        self.errors = {}
        fields = [
            ('method', str),
            ('uri', str),
            ('version', str),
            ('remote_ip', str),
            ('protocol', str),
            ('host', str),
            ('path', str),
            ('query', str),
        ]
        for field, expected_type in fields:
            self.check_type(field, getattr(self.request, field), expected_type)

        # The [0] indexing below requires at least one header, one cookie and
        # one argument on the request; test_types supplies a Cookie and foo=bar.
        self.check_type('header_key', list(self.request.headers.keys())[0], str)
        self.check_type('header_value', list(self.request.headers.values())[0], str)

        self.check_type('cookie_key', list(self.request.cookies.keys())[0], str)
        self.check_type('cookie_value', list(self.request.cookies.values())[0].value, str)
        # secure cookies

        self.check_type('arg_key', list(self.request.arguments.keys())[0], str)
        self.check_type('arg_value', list(self.request.arguments.values())[0][0], bytes)

    def post(self):
        self.check_type('body', self.request.body, bytes)
        self.write(self.errors)

    def get(self):
        self.write(self.errors)

    def check_type(self, name, obj, expected_type):
        actual_type = type(obj)
        if expected_type != actual_type:
            self.errors[name] = "expected %s, got %s" % (expected_type,
                                                         actual_type)


class HTTPServerTest(AsyncHTTPTestCase):
    def get_app(self):
        return Application([("/echo", EchoHandler),
                            ("/typecheck", TypeCheckHandler),
                            ("//doubleslash", EchoHandler),
                            ])

    def test_query_string_encoding(self):
        response = self.fetch("/echo?foo=%C3%A9")
        data = json_decode(response.body)
        self.assertEqual(data, {u"foo": [u"\u00e9"]})

    def test_empty_query_string(self):
        response = self.fetch("/echo?foo=&foo=")
        data = json_decode(response.body)
        self.assertEqual(data, {u"foo": [u"", u""]})

    def test_empty_post_parameters(self):
        response = self.fetch("/echo", method="POST", body="foo=&bar=")
        data = json_decode(response.body)
        self.assertEqual(data, {u"foo": [u""], u"bar": [u""]})

    def test_types(self):
        headers = {"Cookie": "foo=bar"}
        response = self.fetch("/typecheck?foo=bar", headers=headers)
        data = json_decode(response.body)
        self.assertEqual(data, {})

        response = self.fetch("/typecheck", method="POST", body="foo=bar", headers=headers)
        data = json_decode(response.body)
        self.assertEqual(data, {})

    def test_double_slash(self):
        # urlparse.urlsplit (which tornado.httpserver used to use
        # incorrectly) would parse paths beginning with "//" as
        # protocol-relative urls.
        response = self.fetch("//doubleslash")
        self.assertEqual(200, response.code)
        self.assertEqual(json_decode(response.body), {})

    def test_malformed_body(self):
        # parse_qs is pretty forgiving, but it will fail on python 3
        # if the data is not utf8. On python 2 parse_qs will work,
        # but then the recursive_unicode call in EchoHandler will
        # fail.
        # `str is bytes` only holds on Python 2, so this test is a silent
        # no-op there.
        if str is bytes:
            return
        with ExpectLog(gen_log, 'Invalid x-www-form-urlencoded body'):
            response = self.fetch(
                '/echo', method="POST",
                headers={'Content-Type': 'application/x-www-form-urlencoded'},
                body=b'\xe9')
        self.assertEqual(200, response.code)
        self.assertEqual(b'{}', response.body)


class HTTPServerRawTest(AsyncHTTPTestCase):
    # Each test gets self.stream, a raw IOStream connected to the server in
    # setUp, so it can write arbitrary (including malformed) bytes.
    def get_app(self):
        return Application([
            ('/echo', EchoHandler),
        ])

    def setUp(self):
        super(HTTPServerRawTest, self).setUp()
        self.stream = IOStream(socket.socket())
        self.stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
        self.wait()

    def tearDown(self):
        self.stream.close()
        super(HTTPServerRawTest, self).tearDown()

    def test_empty_request(self):
        self.stream.close()
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

    def test_malformed_first_line(self):
        with ExpectLog(gen_log, '.*Malformed HTTP request line'):
            self.stream.write(b'asdf\r\n\r\n')
            # TODO: need an async version of ExpectLog so we don't need
            # hard-coded timeouts here.
            self.io_loop.add_timeout(datetime.timedelta(seconds=0.05),
                                     self.stop)
            self.wait()

    def test_malformed_headers(self):
        with ExpectLog(gen_log, '.*Malformed HTTP headers'):
            self.stream.write(b'GET / HTTP/1.0\r\nasdf\r\n\r\n')
            self.io_loop.add_timeout(datetime.timedelta(seconds=0.05),
                                     self.stop)
            self.wait()

    def test_chunked_request_body(self):
        # Chunked requests are not widely supported and we don't have a way
        # to generate them in AsyncHTTPClient, but HTTPServer will read them.
        self.stream.write(b"""\
POST /echo HTTP/1.1
Transfer-Encoding: chunked
Content-Type: application/x-www-form-urlencoded

4
foo=
3
bar
0

""".replace(b"\n", b"\r\n"))
        read_stream_body(self.stream, self.stop)
        headers, response = self.wait()
        self.assertEqual(json_decode(response), {u'foo': [u'bar']})

    def test_chunked_request_uppercase(self):
        # As per RFC 2616 section 3.6, "Transfer-Encoding" header's value is
        # case-insensitive.
        self.stream.write(b"""\
POST /echo HTTP/1.1
Transfer-Encoding: Chunked
Content-Type: application/x-www-form-urlencoded

4
foo=
3
bar
0

""".replace(b"\n", b"\r\n"))
        read_stream_body(self.stream, self.stop)
        headers, response = self.wait()
        self.assertEqual(json_decode(response), {u'foo': [u'bar']})

    def test_invalid_content_length(self):
        with ExpectLog(gen_log, '.*Only integer Content-Length is allowed'):
            self.stream.write(b"""\
POST /echo HTTP/1.1
Content-Length: foo

bar

""".replace(b"\n", b"\r\n"))
            # The server is expected to drop the connection, which ends this
            # read.
            self.stream.read_until_close(self.stop)
            self.wait()


class XHeaderTest(HandlerBaseTestCase):
    # The handler echoes the remote_ip and protocol the server derived from
    # X-Real-IP / X-Forwarded-For / X-Scheme / X-Forwarded-Proto.
    class Handler(RequestHandler):
        def get(self):
            self.write(dict(remote_ip=self.request.remote_ip,
                            remote_protocol=self.request.protocol))

    def get_httpserver_options(self):
        # 5.5.5.5 is a trusted downstream proxy. See test_trusted_downstream,
        # where it is skipped when picking the client address.
        return dict(xheaders=True, trusted_downstream=['5.5.5.5'])

    def test_ip_headers(self):
        self.assertEqual(self.fetch_json("/")["remote_ip"], "127.0.0.1")

        valid_ipv4 = {"X-Real-IP": "4.4.4.4"}
        self.assertEqual(
            self.fetch_json("/", headers=valid_ipv4)["remote_ip"],
            "4.4.4.4")

        valid_ipv4_list = {"X-Forwarded-For": "127.0.0.1, 4.4.4.4"}
        self.assertEqual(
            self.fetch_json("/", headers=valid_ipv4_list)["remote_ip"],
            "4.4.4.4")

        valid_ipv6 = {"X-Real-IP": "2620:0:1cfe:face:b00c::3"}
        self.assertEqual(
            self.fetch_json("/", headers=valid_ipv6)["remote_ip"],
            "2620:0:1cfe:face:b00c::3")

        valid_ipv6_list = {"X-Forwarded-For": "::1, 2620:0:1cfe:face:b00c::3"}
        self.assertEqual(
            self.fetch_json("/", headers=valid_ipv6_list)["remote_ip"],
            "2620:0:1cfe:face:b00c::3")

        # Invalid header values fall back to the real peer address.
        invalid_chars = {"X-Real-IP": "4.4.4.4<script>"}
        self.assertEqual(
            self.fetch_json("/", headers=invalid_chars)["remote_ip"],
            "127.0.0.1")

        invalid_chars_list = {"X-Forwarded-For": "4.4.4.4, 5.5.5.5<script>"}
        self.assertEqual(
            self.fetch_json("/", headers=invalid_chars_list)["remote_ip"],
            "127.0.0.1")

        invalid_host = {"X-Real-IP": "www.google.com"}
        self.assertEqual(
            self.fetch_json("/", headers=invalid_host)["remote_ip"],
            "127.0.0.1")

    def test_trusted_downstream(self):

        valid_ipv4_list = {"X-Forwarded-For": "127.0.0.1, 4.4.4.4, 5.5.5.5"}
        self.assertEqual(
            self.fetch_json("/", headers=valid_ipv4_list)["remote_ip"],
            "4.4.4.4")

    def test_scheme_headers(self):
        self.assertEqual(self.fetch_json("/")["remote_protocol"], "http")

        https_scheme = {"X-Scheme": "https"}
        self.assertEqual(
            self.fetch_json("/", headers=https_scheme)["remote_protocol"],
            "https")

        https_forwarded = {"X-Forwarded-Proto": "https"}
        self.assertEqual(
            self.fetch_json("/", headers=https_forwarded)["remote_protocol"],
            "https")

        bad_forwarded = {"X-Forwarded-Proto": "unknown"}
        self.assertEqual(
            self.fetch_json("/", headers=bad_forwarded)["remote_protocol"],
            "http")


class SSLXHeaderTest(AsyncHTTPSTestCase, HandlerBaseTestCase):
    # Same handler over HTTPS. An unrecognised X-Scheme falls back to the
    # real scheme (https).
    def get_app(self):
        return Application([('/', XHeaderTest.Handler)])

    def get_httpserver_options(self):
        output = super(SSLXHeaderTest, self).get_httpserver_options()
        output['xheaders'] = True
        return output

    def test_request_without_xprotocol(self):
        self.assertEqual(self.fetch_json("/")["remote_protocol"], "https")

        http_scheme = {"X-Scheme": "http"}
        self.assertEqual(
            self.fetch_json("/", headers=http_scheme)["remote_protocol"], "http")

        bad_scheme = {"X-Scheme": "unknown"}
        self.assertEqual(
            self.fetch_json("/", headers=bad_scheme)["remote_protocol"], "https")


class ManualProtocolTest(HandlerBaseTestCase):
    # The server's `protocol` option overrides the scheme reported to
    # handlers.
    class Handler(RequestHandler):
        def get(self):
            self.write(dict(protocol=self.request.protocol))

    def get_httpserver_options(self):
        return dict(protocol='https')

    def test_manual_protocol(self):
        self.assertEqual(self.fetch_json('/')['protocol'], 'https')


@unittest.skipIf(not hasattr(socket, 'AF_UNIX') or sys.platform == 'cygwin',
                 "unix sockets not supported on this platform")
class UnixSocketTest(AsyncTestCase):
    """HTTPServers can listen on Unix sockets too.

    Why would you want to do this? Nginx can proxy to backends listening
    on unix sockets, for one thing (and managing a namespace for unix
    sockets can be easier than managing a bunch of TCP port numbers).

    Unfortunately, there's no way to specify a unix socket in a url for
    an HTTP client, so we have to test this by hand.
    """
    def setUp(self):
        super(UnixSocketTest, self).setUp()
        # The socket file lives in a private temp dir, removed in tearDown.
        self.tmpdir = tempfile.mkdtemp()
        self.sockfile = os.path.join(self.tmpdir, "test.sock")
        sock = netutil.bind_unix_socket(self.sockfile)
        app = Application([("/hello", HelloWorldRequestHandler)])
        self.server = HTTPServer(app, io_loop=self.io_loop)
        self.server.add_socket(sock)
        self.stream = IOStream(socket.socket(socket.AF_UNIX), io_loop=self.io_loop)
        self.stream.connect(self.sockfile, self.stop)
        self.wait()

    def tearDown(self):
        self.stream.close()
        self.server.stop()
        shutil.rmtree(self.tmpdir)
        super(UnixSocketTest, self).tearDown()

    def test_unix_socket(self):
        self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n")
        self.stream.read_until(b"\r\n", self.stop)
        response = self.wait()
        self.assertEqual(response, b"HTTP/1.1 200 OK\r\n")
        self.stream.read_until(b"\r\n\r\n", self.stop)
        headers = HTTPHeaders.parse(self.wait().decode('latin1'))
        self.stream.read_bytes(int(headers["Content-Length"]), self.stop)
        body = self.wait()
        self.assertEqual(body, b"Hello world")

    def test_unix_socket_bad_request(self):
        # Unix sockets don't have remote addresses so they just return an
        # empty string.
        with ExpectLog(gen_log, "Malformed HTTP message from"):
            self.stream.write(b"garbage\r\n\r\n")
            self.stream.read_until_close(self.stop)
            response = self.wait()
        self.assertEqual(response, b"")


class KeepAliveTest(AsyncHTTPTestCase):
    """Tests various scenarios for HTTP 1.1 keep-alive support.

    These tests don't use AsyncHTTPClient because we want to control
    connection reuse and closing.
    """
    def get_app(self):
        class HelloHandler(RequestHandler):
            def get(self):
                self.finish('Hello world')

            def post(self):
                self.finish('Hello world')

        class LargeHandler(RequestHandler):
            def get(self):
                # 512KB should be bigger than the socket buffers so it will
                # be written out in chunks.
                self.write(''.join(chr(i % 256) * 1024 for i in range(512)))

        class FinishOnCloseHandler(RequestHandler):
            @asynchronous
            def get(self):
                self.flush()

            def on_connection_close(self):
                # This is not very realistic, but finishing the request
                # from the close callback has the right timing to mimic
                # some errors seen in the wild.
                self.finish('closed')

        return Application([('/', HelloHandler),
                            ('/large', LargeHandler),
                            ('/finish_on_close', FinishOnCloseHandler)])

    def setUp(self):
        super(KeepAliveTest, self).setUp()
        # NOTE(review): http_version is assigned here and in several tests
        # but never read in this class as shown. read_headers always expects
        # "HTTP/1.1 200".
        self.http_version = b'HTTP/1.1'

    def tearDown(self):
        # We just closed the client side of the socket; let the IOLoop run
        # once to make sure the server side got the message.
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

        # close() deletes self.stream, so it only exists here if a test
        # ended without calling close().
        if hasattr(self, 'stream'):
            self.stream.close()
        super(KeepAliveTest, self).tearDown()

    # The next few methods are a crude manual http client
    def connect(self):
        self.stream = IOStream(socket.socket(), io_loop=self.io_loop)
        self.stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
        self.wait()

    def read_headers(self):
        # Read and check the status line, then parse and return the headers.
        self.stream.read_until(b'\r\n', self.stop)
        first_line = self.wait()
        self.assertTrue(first_line.startswith(b'HTTP/1.1 200'), first_line)
        self.stream.read_until(b'\r\n\r\n', self.stop)
        header_bytes = self.wait()
        headers = HTTPHeaders.parse(header_bytes.decode('latin1'))
        return headers

    def read_response(self):
        # Read one full response, keep its headers on self.headers, and
        # assert that the body is 'Hello world'.
        self.headers = self.read_headers()
        self.stream.read_bytes(int(self.headers['Content-Length']), self.stop)
        body = self.wait()
        self.assertEqual(b'Hello world', body)

    def close(self):
        self.stream.close()
        del self.stream

    def test_two_requests(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.close()

    def test_request_close(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\nConnection: close\r\n\r\n')
        self.read_response()
        self.stream.read_until_close(callback=self.stop)
        data = self.wait()
        self.assertTrue(not data)
        self.close()

    # keepalive is supported for http 1.0 too, but it's opt-in
    def test_http10(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\n\r\n')
        self.read_response()
        self.stream.read_until_close(callback=self.stop)
        data = self.wait()
        self.assertTrue(not data)
        self.assertTrue('Connection' not in self.headers)
        self.close()

    def test_http10_keepalive(self):
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.close()

    def test_http10_keepalive_extra_crlf(self):
        # A stray CRLF after the first request must not break the next one.
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.close()

    def test_pipelined_requests(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
        self.read_response()
        self.read_response()
        self.close()

    def test_pipelined_cancel(self):
        self.connect()
        self.stream.write(b'GET / HTTP/1.1\r\n\r\nGET / HTTP/1.1\r\n\r\n')
        # only read once
        self.read_response()
        self.close()

    def test_cancel_during_download(self):
        # Close after reading only 1KB of the 512KB /large response.
        self.connect()
        self.stream.write(b'GET /large HTTP/1.1\r\n\r\n')
        self.read_headers()
        self.stream.read_bytes(1024, self.stop)
        self.wait()
        self.close()

    def test_finish_while_closed(self):
        self.connect()
        self.stream.write(b'GET /finish_on_close HTTP/1.1\r\n\r\n')
        self.read_headers()
        self.close()

    def test_keepalive_chunked(self):
        # An empty chunked POST (just the terminating 0-chunk) must leave the
        # HTTP/1.0 keep-alive connection reusable.
        self.http_version = b'HTTP/1.0'
        self.connect()
        self.stream.write(b'POST / HTTP/1.0\r\n'
                          b'Connection: keep-alive\r\n'
                          b'Transfer-Encoding: chunked\r\n'
                          b'\r\n'
                          b'0\r\n'
                          b'\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.stream.write(b'GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n')
        self.read_response()
        self.assertEqual(self.headers['Connection'], 'Keep-Alive')
        self.close()


class GzipBaseTest(object):
    # Mixin: posts gzip-compressed form bodies to EchoHandler. Subclasses
    # decide whether the server decompresses them.
    # NOTE(review): assertEquals is a deprecated alias of assertEqual
    # (removed from the stdlib unittest in Python 3.12). Switching requires
    # a code change, not just a comment.
    def get_app(self):
        return Application([('/', EchoHandler)])

    def post_gzip(self, body):
        bytesio = BytesIO()
        gzip_file = gzip.GzipFile(mode='w', fileobj=bytesio)
        gzip_file.write(utf8(body))
        gzip_file.close()
        compressed_body = bytesio.getvalue()
        return self.fetch('/', method='POST', body=compressed_body,
                          headers={'Content-Encoding': 'gzip'})

    def test_uncompressed(self):
        response = self.fetch('/', method='POST', body='foo=bar')
        self.assertEquals(json_decode(response.body), {u'foo': [u'bar']})


class GzipTest(GzipBaseTest, AsyncHTTPTestCase):
    def get_httpserver_options(self):
        return dict(decompress_request=True)

    def test_gzip(self):
        response = self.post_gzip('foo=bar')
        self.assertEquals(json_decode(response.body), {u'foo': [u'bar']})


class GzipUnsupportedTest(GzipBaseTest, AsyncHTTPTestCase):
    def test_gzip_unsupported(self):
        # Gzip support is opt-in; without it the server fails to parse
        # the body (but parsing form bodies is currently just a log message,
        # not a fatal error).
        with ExpectLog(gen_log, "Unsupported Content-Encoding"):
            response = self.post_gzip('foo=bar')
        self.assertEquals(json_decode(response.body), {})


class StreamingChunkSizeTest(AsyncHTTPTestCase):
    # Checks that the server hands request bodies to the delegate in pieces
    # of at most CHUNK_SIZE bytes (the server's chunk_size option).
    # 50 characters long, and repetitive so it can be compressed.
    BODY = b'01234567890123456789012345678901234567890123456789'
    CHUNK_SIZE = 16

    def get_http_client(self):
        # body_producer doesn't work on curl_httpclient, so override the
        # configured AsyncHTTPClient implementation.
        return SimpleAsyncHTTPClient(io_loop=self.io_loop)

    def get_httpserver_options(self):
        return dict(chunk_size=self.CHUNK_SIZE, decompress_request=True)

    class MessageDelegate(HTTPMessageDelegate):
        # Records the length of each body chunk it receives and replies with
        # that list as JSON.
        def __init__(self, connection):
            self.connection = connection

        def headers_received(self, start_line, headers):
            self.chunk_lengths = []

        def data_received(self, chunk):
            self.chunk_lengths.append(len(chunk))

        def finish(self):
            response_body = utf8(json_encode(self.chunk_lengths))
            self.connection.write_headers(
                ResponseStartLine('HTTP/1.1', 200, 'OK'),
                HTTPHeaders({'Content-Length': str(len(response_body))}))
            self.connection.write(response_body)
            self.connection.finish()

    def get_app(self):
        # A bare HTTPServerConnectionDelegate (not a web.Application), so
        # body chunks reach MessageDelegate directly.
        class App(HTTPServerConnectionDelegate):
            def start_request(self, server_conn, request_conn):
                return StreamingChunkSizeTest.MessageDelegate(request_conn)
        return App()

    def fetch_chunk_sizes(self, **kwargs):
        # POST with the given body options and return the chunk sizes the
        # server saw. Asserts they sum to len(BODY) and each lies in
        # (0, CHUNK_SIZE].
        response = self.fetch('/', method='POST', **kwargs)
        response.rethrow()
        chunks = json_decode(response.body)
        self.assertEqual(len(self.BODY), sum(chunks))
        for chunk_size in chunks:
            self.assertLessEqual(chunk_size, self.CHUNK_SIZE,
                                 'oversized chunk: ' + str(chunks))
            self.assertGreater(chunk_size, 0,
                               'empty chunk: ' + str(chunks))
        return chunks

    def compress(self, body):
        # Gzip `body`, and fail loudly if compression did not make it smaller.
        bytesio = BytesIO()
        gzfile = gzip.GzipFile(mode='w', fileobj=bytesio)
        gzfile.write(body)
        gzfile.close()
        compressed = bytesio.getvalue()
        if len(compressed) >= len(body):
            raise Exception("body did not shrink when compressed")
        return compressed

    def test_regular_body(self):
        chunks = self.fetch_chunk_sizes(body=self.BODY)
        # Without compression we know exactly what to expect.
        self.assertEqual([16, 16, 16, 2], chunks)

    def test_compressed_body(self):
        self.fetch_chunk_sizes(body=self.compress(self.BODY),
                               headers={'Content-Encoding': 'gzip'})
        # Compression creates irregular boundaries so the assertions
        # in fetch_chunk_sizes are as specific as we can get.

    def test_chunked_body(self):
        def body_producer(write):
            write(self.BODY[:20])
            write(self.BODY[20:])
        chunks = self.fetch_chunk_sizes(body_producer=body_producer)
        # HTTP chunk boundaries translate to application-visible breaks
        # (20 -> 16 + 4, then 30 -> 16 + 14).
        self.assertEqual([16, 4, 16, 14], chunks)

    def test_chunked_compressed(self):
        compressed = self.compress(self.BODY)
        self.assertGreater(len(compressed), 20)

        def body_producer(write):
            write(compressed[:20])
            write(compressed[20:])
        self.fetch_chunk_sizes(body_producer=body_producer,
                               headers={'Content-Encoding': 'gzip'})


class MaxHeaderSizeTest(AsyncHTTPTestCase):
    def get_app(self):
        return Application([('/', HelloWorldRequestHandler)])

    def get_httpserver_options(self):
        return dict(max_header_size=1024)

    def test_small_headers(self):
        response = self.fetch("/", headers={'X-Filler': 'a' * 100})
        response.rethrow()
        self.assertEqual(response.body, b"Hello world")

    def test_large_headers(self):
        with ExpectLog(gen_log, "Unsatisfiable read", required=False):
            response = self.fetch("/", headers={'X-Filler': 'a' * 1000})
        # 431 is "Request Header Fields Too Large", defined in RFC
        # 6585. However, many implementations just close the
        # connection in this case, resulting in a 599.
        self.assertIn(response.code, (431, 599))


@skipOnTravis
class IdleTimeoutTest(AsyncHTTPTestCase):
    # The server is configured to close connections idle for more than 0.1s.
    def get_app(self):
        return Application([('/', HelloWorldRequestHandler)])

    def get_httpserver_options(self):
        return dict(idle_connection_timeout=0.1)

    def setUp(self):
        super(IdleTimeoutTest, self).setUp()
        self.streams = []

    def tearDown(self):
        # NOTE(review): streams are closed after the base tearDown has run.
        # KeepAliveTest closes its stream before calling super; confirm this
        # ordering is intended.
        super(IdleTimeoutTest, self).tearDown()
        for stream in self.streams:
            stream.close()

    def connect(self):
        # Open a raw connection and remember it so tearDown can close it.
        stream = IOStream(socket.socket())
        stream.connect(('127.0.0.1', self.get_http_port()), self.stop)
        self.wait()
        self.streams.append(stream)
        return stream

    def test_unused_connection(self):
        stream = self.connect()
        stream.set_close_callback(self.stop)
        self.wait()

    def test_idle_after_use(self):
        stream = self.connect()
        stream.set_close_callback(lambda: self.stop("closed"))

        # Use the connection twice to make sure keep-alives are working
        for i in range(2):
            stream.write(b"GET / HTTP/1.1\r\n\r\n")
            stream.read_until(b"\r\n\r\n", self.stop)
            self.wait()
            # 11 == len(b"Hello world")
            stream.read_bytes(11, self.stop)
            data = self.wait()
            self.assertEqual(data, b"Hello world")

        # Now let the timeout trigger and close the connection.
        data = self.wait()
        self.assertEqual(data, "closed")


class BodyLimitsTest(AsyncHTTPTestCase):
    # The server-wide max_body_size is 4096 bytes. StreamingHandler can raise
    # it per request via ?expected_size=N, and can set a body timeout via
    # ?body_timeout=S.
    def get_app(self):
        class BufferedHandler(RequestHandler):
            def put(self):
                self.write(str(len(self.request.body)))

        @stream_request_body
        class StreamingHandler(RequestHandler):
            def initialize(self):
                self.bytes_read = 0

            def prepare(self):
                if 'expected_size' in self.request.arguments:
                    self.request.connection.set_max_body_size(
                        int(self.get_argument('expected_size')))
                if 'body_timeout' in self.request.arguments:
                    self.request.connection.set_body_timeout(
                        float(self.get_argument('body_timeout')))

            def data_received(self, data):
                self.bytes_read += len(data)

            def put(self):
                self.write(str(self.bytes_read))

        return Application([('/buffered', BufferedHandler),
                            ('/streaming', StreamingHandler)])

    def get_httpserver_options(self):
        return dict(body_timeout=3600, max_body_size=4096)

    def get_http_client(self):
        # body_producer doesn't work on curl_httpclient, so override the
        # configured AsyncHTTPClient implementation.
        return SimpleAsyncHTTPClient(io_loop=self.io_loop)

    def test_small_body(self):
        response = self.fetch('/buffered', method='PUT', body=b'a' * 4096)
        self.assertEqual(response.body, b'4096')
        response = self.fetch('/streaming', method='PUT', body=b'a' * 4096)
        self.assertEqual(response.body, b'4096')

    def test_large_body_buffered(self):
        with ExpectLog(gen_log, '.*Content-Length too long'):
            response = self.fetch('/buffered', method='PUT', body=b'a' * 10240)
        self.assertEqual(response.code, 599)

    def test_large_body_buffered_chunked(self):
        with ExpectLog(gen_log, '.*chunked body too large'):
            response = self.fetch('/buffered', method='PUT',
                                  body_producer=lambda write: write(b'a' * 10240))
        self.assertEqual(response.code, 599)

    def test_large_body_streaming(self):
        with ExpectLog(gen_log, '.*Content-Length too long'):
            response = self.fetch('/streaming', method='PUT', body=b'a' * 10240)
        self.assertEqual(response.code, 599)

    def test_large_body_streaming_chunked(self):
        with ExpectLog(gen_log, '.*chunked body too large'):
            response = self.fetch('/streaming', method='PUT',
                                  body_producer=lambda write: write(b'a' * 10240))
        self.assertEqual(response.code, 599)

    def test_large_body_streaming_override(self):
        response = self.fetch('/streaming?expected_size=10240', method='PUT',
                              body=b'a' * 10240)
        self.assertEqual(response.body, b'10240')

    def test_large_body_streaming_chunked_override(self):
        response = self.fetch('/streaming?expected_size=10240', method='PUT',
                              body_producer=lambda write: write(b'a' * 10240))
        self.assertEqual(response.body, b'10240')

    @gen_test
    def test_timeout(self):
        stream = IOStream(socket.socket())
        try:
            yield stream.connect(('127.0.0.1', self.get_http_port()))
            # Use a raw stream because AsyncHTTPClient won't let us read a
            # response without
finishing a body. 1079 stream.write(b'PUT /streaming?body_timeout=0.1 HTTP/1.0\r\n' 1080 b'Content-Length: 42\r\n\r\n') 1081 with ExpectLog(gen_log, 'Timeout reading body'): 1082 response = yield stream.read_until_close() 1083 self.assertEqual(response, b'') 1084 finally: 1085 stream.close() 1086 1087 @gen_test 1088 def test_body_size_override_reset(self): 1089 # The max_body_size override is reset between requests. 1090 stream = IOStream(socket.socket()) 1091 try: 1092 yield stream.connect(('127.0.0.1', self.get_http_port())) 1093 # Use a raw stream so we can make sure it's all on one connection. 1094 stream.write(b'PUT /streaming?expected_size=10240 HTTP/1.1\r\n' 1095 b'Content-Length: 10240\r\n\r\n') 1096 stream.write(b'a' * 10240) 1097 headers, response = yield gen.Task(read_stream_body, stream) 1098 self.assertEqual(response, b'10240') 1099 # Without the ?expected_size parameter, we get the old default value 1100 stream.write(b'PUT /streaming HTTP/1.1\r\n' 1101 b'Content-Length: 10240\r\n\r\n') 1102 with ExpectLog(gen_log, '.*Content-Length too long'): 1103 data = yield stream.read_until_close() 1104 self.assertEqual(data, b'') 1105 finally: 1106 stream.close() 1107 1108 1109class LegacyInterfaceTest(AsyncHTTPTestCase): 1110 def get_app(self): 1111 # The old request_callback interface does not implement the 1112 # delegate interface, and writes its response via request.write 1113 # instead of request.connection.write_headers. 1114 def handle_request(request): 1115 self.http1 = request.version.startswith("HTTP/1.") 1116 if not self.http1: 1117 # This test will be skipped if we're using HTTP/2, 1118 # so just close it out cleanly using the modern interface. 
1119 request.connection.write_headers( 1120 ResponseStartLine('', 200, 'OK'), 1121 HTTPHeaders()) 1122 request.connection.finish() 1123 return 1124 message = b"Hello world" 1125 request.write(utf8("HTTP/1.1 200 OK\r\n" 1126 "Content-Length: %d\r\n\r\n" % len(message))) 1127 request.write(message) 1128 request.finish() 1129 return handle_request 1130 1131 def test_legacy_interface(self): 1132 response = self.fetch('/') 1133 if not self.http1: 1134 self.skipTest("requires HTTP/1.x") 1135 self.assertEqual(response.body, b"Hello world") 1136