# coding: utf-8
"""Tests for ``tornado.curl_httpclient.CurlAsyncHTTPClient``.

Every test case here is skipped when the ``pycurl`` module is not importable.
"""
from __future__ import absolute_import, division, print_function

from hashlib import md5

from tornado.escape import utf8
from tornado.httpclient import HTTPRequest
from tornado.stack_context import ExceptionStackContext
from tornado.testing import AsyncHTTPTestCase
from tornado.test import httpclient_test
from tornado.test.util import unittest
from tornado.web import Application, RequestHandler


try:
    import pycurl  # type: ignore
except ImportError:
    pycurl = None

if pycurl is not None:
    from tornado.curl_httpclient import CurlAsyncHTTPClient


@unittest.skipIf(pycurl is None, "pycurl module not present")
class CurlHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase):
    """Run the shared HTTP client test suite against the curl client."""

    def get_http_client(self):
        """Return a curl client bound to this test's IOLoop, IPv6 disabled."""
        client = CurlAsyncHTTPClient(io_loop=self.io_loop,
                                     defaults=dict(allow_ipv6=False))
        # make sure AsyncHTTPClient magic doesn't give us the wrong class
        self.assertIsInstance(client, CurlAsyncHTTPClient)
        return client


class DigestAuthHandler(RequestHandler):
    """Minimal HTTP Digest auth endpoint (no ``qop``) used by the tests.

    Without an ``Authorization`` header it answers 401 with a Digest
    challenge. With one, it parses the parameters, asserts they echo the
    fixed challenge values, and writes ``ok`` if the client's ``response``
    hash is correct or ``fail`` otherwise.
    """

    def get(self):
        realm = 'test'
        opaque = 'asdf'
        # Real implementations would use a random nonce.
        nonce = "1234"
        username = 'foo'
        password = 'bar'

        auth_header = self.request.headers.get('Authorization')
        if auth_header is None:
            # First round trip: issue the challenge the client must answer.
            self.set_status(401)
            self.set_header('WWW-Authenticate',
                            'Digest realm="%s", nonce="%s", opaque="%s"' %
                            (realm, nonce, opaque))
            return

        auth_mode, params = auth_header.split(' ', 1)
        assert auth_mode == 'Digest'
        # NOTE(review): naive comma split; good enough for these fixed test
        # values but would mis-parse quoted values that contain commas.
        param_dict = {}
        for pair in params.split(','):
            k, v = pair.strip().split('=', 1)
            # Strip surrounding quotes; the length check avoids an
            # IndexError on an empty value.
            if len(v) >= 2 and v[0] == '"' and v[-1] == '"':
                v = v[1:-1]
            param_dict[k] = v
        assert param_dict['realm'] == realm
        assert param_dict['opaque'] == opaque
        assert param_dict['nonce'] == nonce
        assert param_dict['username'] == username
        assert param_dict['uri'] == self.request.path
        # Expected response = MD5(MD5(user:realm:pass):nonce:MD5(method:uri)).
        h1 = md5(utf8('%s:%s:%s' % (username, realm, password))).hexdigest()
        h2 = md5(utf8('%s:%s' % (self.request.method,
                                 self.request.path))).hexdigest()
        digest = md5(utf8('%s:%s:%s' % (h1, nonce, h2))).hexdigest()
        if digest == param_dict['response']:
            self.write('ok')
        else:
            self.write('fail')


class CustomReasonHandler(RequestHandler):
    """Respond 200 with a non-standard reason phrase."""

    def get(self):
        self.set_status(200, "Custom reason")


class CustomFailReasonHandler(RequestHandler):
    """Respond 400 with a non-standard reason phrase."""

    def get(self):
        self.set_status(400, "Custom reason")


@unittest.skipIf(pycurl is None, "pycurl module not present")
class CurlHTTPClientTestCase(AsyncHTTPTestCase):
    """Curl-specific client behaviour exercised against a local app."""

    def get_http_client(self):
        # Provide the curl client through the AsyncHTTPTestCase hook instead
        # of overwriting self.http_client after setUp: no default client is
        # created and then abandoned unclosed, and the client actually used
        # is the one the base class closes in tearDown.
        return self.create_client()

    def get_app(self):
        return Application([
            ('/digest', DigestAuthHandler),
            ('/custom_reason', CustomReasonHandler),
            ('/custom_fail_reason', CustomFailReasonHandler),
        ])

    def create_client(self, **kwargs):
        """Return a new, non-shared CurlAsyncHTTPClient on this test's IOLoop.

        IPv6 is disabled via request defaults; extra ``kwargs`` (for example
        ``max_clients``) are forwarded to the client constructor.
        """
        return CurlAsyncHTTPClient(self.io_loop, force_instance=True,
                                   defaults=dict(allow_ipv6=False),
                                   **kwargs)

    def test_prepare_curl_callback_stack_context(self):
        """An exception in prepare_curl_callback reaches the stack context
        that was active when the request was created."""
        exc_info = []

        def error_handler(typ, value, tb):
            exc_info.append((typ, value, tb))
            self.stop()
            return True  # report the exception as handled

        with ExceptionStackContext(error_handler):
            request = HTTPRequest(self.get_url('/'),
                                  prepare_curl_callback=lambda curl: 1 / 0)
        # The fetch is deliberately outside the ``with`` block: the handler
        # is presumably captured when the callback is attached to the
        # HTTPRequest, not when the request is fetched.
        self.http_client.fetch(request, callback=self.stop)
        self.wait()
        self.assertEqual(1, len(exc_info))
        self.assertIs(exc_info[0][0], ZeroDivisionError)

    def test_digest_auth(self):
        response = self.fetch('/digest', auth_mode='digest',
                              auth_username='foo', auth_password='bar')
        self.assertEqual(response.body, b'ok')

    def test_custom_reason(self):
        response = self.fetch('/custom_reason')
        self.assertEqual(response.reason, "Custom reason")

    def test_fail_custom_reason(self):
        response = self.fetch('/custom_fail_reason')
        self.assertEqual(str(response.error), "HTTP 400: Custom reason")

    def test_failed_setup(self):
        # Close the current client before replacing it; only the final
        # self.http_client is closed during teardown, so an overwritten
        # client would otherwise keep its curl handles open.
        self.http_client.close()
        self.http_client = self.create_client(max_clients=1)
        # Every fetch of this non-ASCII path is expected to fail. Repeating
        # it more times than max_clients presumably checks that a failed
        # request setup frees the single curl handle for the next request.
        for _ in range(5):
            response = self.fetch(u'/ユニコード')
            self.assertIsNot(response.error, None)