1# coding: utf-8
2from __future__ import absolute_import, division, print_function
3
4from hashlib import md5
5
6from tornado.escape import utf8
7from tornado.httpclient import HTTPRequest
8from tornado.stack_context import ExceptionStackContext
9from tornado.testing import AsyncHTTPTestCase
10from tornado.test import httpclient_test
11from tornado.test.util import unittest
12from tornado.web import Application, RequestHandler
13
14
15try:
16    import pycurl  # type: ignore
17except ImportError:
18    pycurl = None
19
20if pycurl is not None:
21    from tornado.curl_httpclient import CurlAsyncHTTPClient
22
23
@unittest.skipIf(pycurl is None, "pycurl module not present")
class CurlHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase):
    """Run the shared HTTP client test suite against CurlAsyncHTTPClient."""

    def get_http_client(self):
        """Build the client under test: curl-based, on the test IOLoop, IPv6 off."""
        client = CurlAsyncHTTPClient(io_loop=self.io_loop,
                                     defaults=dict(allow_ipv6=False))
        # make sure AsyncHTTPClient magic doesn't give us the wrong class;
        # assertIsInstance reports the actual type on failure.
        self.assertIsInstance(client, CurlAsyncHTTPClient)
        return client
32
33
class DigestAuthHandler(RequestHandler):
    """Minimal HTTP Digest authentication endpoint used by the curl tests.

    A request without an ``Authorization`` header gets a 401 carrying a
    ``WWW-Authenticate: Digest`` challenge.  A request with one has its
    echoed challenge parameters checked.  Its ``response`` hash is then
    compared against ``MD5(HA1:nonce:HA2)``, which is the original
    RFC 2069 form with no ``qop``.  The body is ``ok`` on a match and
    ``fail`` otherwise.
    """

    def get(self):
        realm = 'test'
        opaque = 'asdf'
        # Real implementations would use a random nonce.
        nonce = "1234"
        username = 'foo'
        password = 'bar'

        auth_header = self.request.headers.get('Authorization', None)
        if auth_header is None:
            # First leg of the handshake: issue the challenge.
            self.set_status(401)
            self.set_header('WWW-Authenticate',
                            'Digest realm="%s", nonce="%s", opaque="%s"' %
                            (realm, nonce, opaque))
            return

        auth_mode, params = auth_header.split(' ', 1)
        assert auth_mode == 'Digest'
        param_dict = self._parse_auth_params(params)
        # The client must echo back exactly what the challenge issued.
        assert param_dict['realm'] == realm
        assert param_dict['opaque'] == opaque
        assert param_dict['nonce'] == nonce
        assert param_dict['username'] == username
        assert param_dict['uri'] == self.request.path
        h1 = md5(utf8('%s:%s:%s' % (username, realm, password))).hexdigest()
        h2 = md5(utf8('%s:%s' % (self.request.method,
                                 self.request.path))).hexdigest()
        digest = md5(utf8('%s:%s:%s' % (h1, nonce, h2))).hexdigest()
        if digest == param_dict['response']:
            self.write('ok')
        else:
            self.write('fail')

    @staticmethod
    def _parse_auth_params(params):
        """Parse the comma-separated ``key=value`` list of a Digest header.

        Values wrapped in double quotes are unquoted.  Pairs are split on
        every comma, so quoted values that themselves contain commas are
        not supported.

        Raises:
            ValueError: if a pair has no ``=``.
        """
        param_dict = {}
        for pair in params.split(','):
            key, value = pair.strip().split('=', 1)
            # Require at least two characters.  An empty value (which used
            # to raise IndexError) or a lone '"' is kept verbatim instead
            # of being mangled.
            if len(value) >= 2 and value[0] == '"' and value[-1] == '"':
                value = value[1:-1]
            param_dict[key] = value
        return param_dict
71
72
class CustomReasonHandler(RequestHandler):
    """Answer GET with a 200 status and a non-standard reason phrase."""

    def get(self):
        self.set_status(200, reason="Custom reason")
76
77
class CustomFailReasonHandler(RequestHandler):
    """Answer GET with a 400 status and a non-standard reason phrase."""

    def get(self):
        self.set_status(400, reason="Custom reason")
81
82
@unittest.skipIf(pycurl is None, "pycurl module not present")
class CurlHTTPClientTestCase(AsyncHTTPTestCase):
    """curl-specific behaviour not covered by the shared client suite."""

    def get_http_client(self):
        # Let the base-class setUp build self.http_client through this hook.
        # Reassigning self.http_client after setUp instead left the
        # default client that setUp created without ever being closed.
        return self.create_client()

    def get_app(self):
        return Application([
            ('/digest', DigestAuthHandler),
            ('/custom_reason', CustomReasonHandler),
            ('/custom_fail_reason', CustomFailReasonHandler),
        ])

    def create_client(self, **kwargs):
        """Return a fresh (non-shared) curl client on the test IOLoop.

        IPv6 is disabled by default.  Extra keyword arguments (for example
        ``max_clients``) are forwarded to ``CurlAsyncHTTPClient``.  The
        caller is responsible for closing the returned client.
        """
        return CurlAsyncHTTPClient(self.io_loop, force_instance=True,
                                   defaults=dict(allow_ipv6=False),
                                   **kwargs)

    def test_prepare_curl_callback_stack_context(self):
        """Errors in prepare_curl_callback reach the request's stack context.

        The HTTPRequest is created inside ExceptionStackContext, but the
        fetch happens outside it.  The ZeroDivisionError raised by the
        callback must still be delivered to ``error_handler``.
        """
        exc_info = []

        def error_handler(typ, value, tb):
            exc_info.append((typ, value, tb))
            self.stop()
            return True

        with ExceptionStackContext(error_handler):
            request = HTTPRequest(self.get_url('/'),
                                  prepare_curl_callback=lambda curl: 1 / 0)
        self.http_client.fetch(request, callback=self.stop)
        self.wait()
        self.assertEqual(1, len(exc_info))
        self.assertIs(exc_info[0][0], ZeroDivisionError)

    def test_digest_auth(self):
        response = self.fetch('/digest', auth_mode='digest',
                              auth_username='foo', auth_password='bar')
        self.assertEqual(response.body, b'ok')

    def test_custom_reason(self):
        response = self.fetch('/custom_reason')
        self.assertEqual(response.reason, "Custom reason")

    def test_fail_custom_reason(self):
        response = self.fetch('/custom_fail_reason')
        self.assertEqual(str(response.error), "HTTP 400: Custom reason")

    def test_failed_setup(self):
        """Repeated failing requests must each report an error.

        The non-ASCII path is expected to make request setup fail.  With
        max_clients=1, a curl handle that was not released after a failure
        would presumably stall the later iterations.
        """
        # Close the client from setUp before replacing it; tearDown only
        # closes whatever self.http_client refers to at the end.
        self.http_client.close()
        self.http_client = self.create_client(max_clients=1)
        for _ in range(5):
            response = self.fetch(u'/ユニコード')
            self.assertIsNot(response.error, None)
135