#!/usr/bin/env python
from __future__ import absolute_import, division, print_function

from tornado.escape import utf8, _unicode
from tornado import gen
from tornado.httpclient import HTTPResponse, HTTPError, AsyncHTTPClient, main, _RequestProxy
from tornado import httputil
from tornado.http1connection import HTTP1Connection, HTTP1ConnectionParameters
from tornado.iostream import StreamClosedError
from tornado.netutil import Resolver, OverrideResolver, _client_ssl_defaults
from tornado.log import gen_log
from tornado import stack_context
from tornado.tcpclient import TCPClient
from tornado.util import PY3

import base64
import collections
import copy
import functools
import re
import socket
import sys
from io import BytesIO


if PY3:
    import urllib.parse as urlparse
else:
    import urlparse

try:
    import ssl
except ImportError:
    # ssl is not available on Google App Engine.
    ssl = None

try:
    import certifi
except ImportError:
    certifi = None


def _default_ca_certs():
    if certifi is None:
        raise Exception("The 'certifi' package is required to use https "
                        "in simple_httpclient")
    return certifi.where()


class SimpleAsyncHTTPClient(AsyncHTTPClient):
    """Non-blocking HTTP client with no external dependencies.

    This class implements an HTTP 1.1 client on top of Tornado's IOStreams.
    Some features found in the curl-based AsyncHTTPClient are not yet
    supported.  In particular, proxies are not supported, connections
    are not reused, and callers cannot select the network interface to be
    used.
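
    A minimal usage sketch (assuming a running ``IOLoop``; the URL and
    callback below are illustrative only)::

        from tornado.httpclient import AsyncHTTPClient

        AsyncHTTPClient.configure(SimpleAsyncHTTPClient)
        client = AsyncHTTPClient()

        def handle_response(response):
            print(response.code)

        client.fetch("http://example.com/", handle_response)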
    """
    def initialize(self, io_loop, max_clients=10,
                   hostname_mapping=None, max_buffer_size=104857600,
                   resolver=None, defaults=None, max_header_size=None,
                   max_body_size=None):
        """Creates an AsyncHTTPClient.

        Only a single AsyncHTTPClient instance exists per IOLoop
        in order to enforce a limit on the number of pending connections.
        ``force_instance=True`` may be used to suppress this behavior.

        Note that because of this implicit reuse, unless ``force_instance``
        is used, only the first call to the constructor actually uses
        its arguments. It is recommended to use the ``configure`` method
        instead of the constructor to ensure that arguments take effect.

        ``max_clients`` is the number of concurrent requests that can be
        in progress; when this limit is reached additional requests will be
        queued. Note that time spent waiting in this queue still counts
        against the ``request_timeout``.

        ``hostname_mapping`` is a dictionary mapping hostnames to IP addresses.
        It can be used to make local DNS changes when modifying system-wide
        settings like ``/etc/hosts`` is not possible or desirable (e.g. in
        unittests).

        ``max_buffer_size`` (default 100MB) is the number of bytes
        that can be read into memory at once. ``max_body_size``
        (defaults to ``max_buffer_size``) is the largest response body
        that the client will accept.  Without a
        ``streaming_callback``, the smaller of these two limits
        applies; with a ``streaming_callback`` only ``max_body_size``
        does.

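        For example, the following sketch (parameter values are
        illustrative, not recommendations) configures this class with a
        larger connection pool and a test-only hostname mapping::

            AsyncHTTPClient.configure(
                SimpleAsyncHTTPClient,
                max_clients=50,
                hostname_mapping={"example.com": "127.0.0.1"})
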
        .. versionchanged:: 4.2
           Added the ``max_body_size`` argument.
        """
        super(SimpleAsyncHTTPClient, self).initialize(io_loop,
                                                      defaults=defaults)
        self.max_clients = max_clients
        self.queue = collections.deque()
        self.active = {}
        self.waiting = {}
        self.max_buffer_size = max_buffer_size
        self.max_header_size = max_header_size
        self.max_body_size = max_body_size
        # TCPClient could create a Resolver for us, but we have to do it
        # ourselves to support hostname_mapping.
        if resolver:
            self.resolver = resolver
            self.own_resolver = False
        else:
            self.resolver = Resolver(io_loop=io_loop)
            self.own_resolver = True
        if hostname_mapping is not None:
            self.resolver = OverrideResolver(resolver=self.resolver,
                                             mapping=hostname_mapping)
        self.tcp_client = TCPClient(resolver=self.resolver, io_loop=io_loop)

    def close(self):
        super(SimpleAsyncHTTPClient, self).close()
        if self.own_resolver:
            self.resolver.close()
        self.tcp_client.close()

    def fetch_impl(self, request, callback):
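        # Each request is tracked by a unique sentinel object (``key``) so it
        # can be looked up in self.waiting while queued and in self.active
        # once it is running. If all client slots are busy, a queue timeout
        # is armed here so that time spent waiting in the queue still counts
        # against the request's timeouts.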
        key = object()
        self.queue.append((key, request, callback))
        if len(self.active) >= self.max_clients:
            timeout_handle = self.io_loop.add_timeout(
                self.io_loop.time() + min(request.connect_timeout,
                                          request.request_timeout),
                functools.partial(self._on_timeout, key, "in request queue"))
        else:
            timeout_handle = None
        self.waiting[key] = (request, callback, timeout_handle)
        self._process_queue()
        if self.queue:
            gen_log.debug("max_clients limit reached, request queued. "
                          "%d active, %d queued requests." % (
                              len(self.active), len(self.queue)))

    def _process_queue(self):
        with stack_context.NullContext():
            while self.queue and len(self.active) < self.max_clients:
                key, request, callback = self.queue.popleft()
                if key not in self.waiting:
                    continue
                self._remove_timeout(key)
                self.active[key] = (request, callback)
                release_callback = functools.partial(self._release_fetch, key)
                self._handle_request(request, release_callback, callback)

    def _connection_class(self):
        return _HTTPConnection

    def _handle_request(self, request, release_callback, final_callback):
        self._connection_class()(
            self.io_loop, self, request, release_callback,
            final_callback, self.max_buffer_size, self.tcp_client,
            self.max_header_size, self.max_body_size)

    def _release_fetch(self, key):
        del self.active[key]
        self._process_queue()

    def _remove_timeout(self, key):
        if key in self.waiting:
            request, callback, timeout_handle = self.waiting[key]
            if timeout_handle is not None:
                self.io_loop.remove_timeout(timeout_handle)
            del self.waiting[key]

    def _on_timeout(self, key, info=None):
        """Timeout callback for a queued request.

        Construct a timeout HTTPResponse when a timeout occurs.

        :arg object key: A simple object to mark the request.
        :arg string info: More detailed timeout information.
        """
        request, callback, timeout_handle = self.waiting[key]
        self.queue.remove((key, request, callback))

        error_message = "Timeout {0}".format(info) if info else "Timeout"
        timeout_response = HTTPResponse(
            request, 599, error=HTTPError(599, error_message),
            request_time=self.io_loop.time() - request.start_time)
        self.io_loop.add_callback(callback, timeout_response)
        del self.waiting[key]


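# _HTTPConnection drives a single request/response cycle: it resolves and
# connects via TCPClient, writes the request over an HTTP1Connection, and
# implements the HTTPMessageDelegate interface (headers_received,
# data_received, finish) to collect the response.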
class _HTTPConnection(httputil.HTTPMessageDelegate):
    _SUPPORTED_METHODS = set(["GET", "HEAD", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"])

    def __init__(self, io_loop, client, request, release_callback,
                 final_callback, max_buffer_size, tcp_client,
                 max_header_size, max_body_size):
        self.start_time = io_loop.time()
        self.io_loop = io_loop
        self.client = client
        self.request = request
        self.release_callback = release_callback
        self.final_callback = final_callback
        self.max_buffer_size = max_buffer_size
        self.tcp_client = tcp_client
        self.max_header_size = max_header_size
        self.max_body_size = max_body_size
        self.code = None
        self.headers = None
        self.chunks = []
        self._decompressor = None
        # Timeout handle returned by IOLoop.add_timeout
        self._timeout = None
        self._sockaddr = None
        with stack_context.ExceptionStackContext(self._handle_exception):
            self.parsed = urlparse.urlsplit(_unicode(self.request.url))
            if self.parsed.scheme not in ("http", "https"):
                raise ValueError("Unsupported url scheme: %s" %
                                 self.request.url)
            # urlsplit results have hostname and port attributes, but they
            # didn't support ipv6 literals until python 2.7.
            netloc = self.parsed.netloc
            if "@" in netloc:
                userpass, _, netloc = netloc.rpartition("@")
            host, port = httputil.split_host_and_port(netloc)
            if port is None:
                port = 443 if self.parsed.scheme == "https" else 80
            if re.match(r'^\[.*\]$', host):
                # raw ipv6 addresses in urls are enclosed in brackets
                host = host[1:-1]
            self.parsed_hostname = host  # save final host for _on_connect

            if request.allow_ipv6 is False:
                af = socket.AF_INET
            else:
                af = socket.AF_UNSPEC

            ssl_options = self._get_ssl_options(self.parsed.scheme)

            timeout = min(self.request.connect_timeout, self.request.request_timeout)
            if timeout:
                self._timeout = self.io_loop.add_timeout(
                    self.start_time + timeout,
                    stack_context.wrap(functools.partial(self._on_timeout, "while connecting")))
            self.tcp_client.connect(host, port, af=af,
                                    ssl_options=ssl_options,
                                    max_buffer_size=self.max_buffer_size,
                                    callback=self._on_connect)

    def _get_ssl_options(self, scheme):
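        # Precedence: an explicit request.ssl_options object wins; otherwise
        # the shared default SSLContext is used when no per-request cert
        # settings are given; otherwise a dict of ssl options is built from
        # the request's validate_cert/ca_certs/client_key/client_cert fields.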
        if scheme == "https":
            if self.request.ssl_options is not None:
                return self.request.ssl_options
            # If we are using the defaults, don't construct a
            # new SSLContext.
            if (self.request.validate_cert and
                    self.request.ca_certs is None and
                    self.request.client_cert is None and
                    self.request.client_key is None):
                return _client_ssl_defaults
            ssl_options = {}
            if self.request.validate_cert:
                ssl_options["cert_reqs"] = ssl.CERT_REQUIRED
            if self.request.ca_certs is not None:
                ssl_options["ca_certs"] = self.request.ca_certs
            elif not hasattr(ssl, 'create_default_context'):
                # When create_default_context is present,
                # we can omit the "ca_certs" parameter entirely,
                # which avoids the dependency on "certifi" for py34.
                ssl_options["ca_certs"] = _default_ca_certs()
            if self.request.client_key is not None:
                ssl_options["keyfile"] = self.request.client_key
            if self.request.client_cert is not None:
                ssl_options["certfile"] = self.request.client_cert

            # SSL interoperability is tricky.  We want to disable
            # SSLv2 for security reasons; it wasn't disabled by default
            # until openssl 1.0.  The best way to do this is to use
            # the SSL_OP_NO_SSLv2 option, but that wasn't exposed to python
            # until 3.2.  Python 2.7 adds the ciphers argument, which
            # can also be used to disable SSLv2.  As a last resort
            # on python 2.6, we set ssl_version to TLSv1.  This is
            # more narrow than we'd like since it also breaks
            # compatibility with servers configured for SSLv3 only,
            # but nearly all servers support both SSLv3 and TLSv1:
            # http://blog.ivanristic.com/2011/09/ssl-survey-protocol-support.html
            if sys.version_info >= (2, 7):
                # In addition to disabling SSLv2, we also exclude certain
                # classes of insecure ciphers.
                ssl_options["ciphers"] = "DEFAULT:!SSLv2:!EXPORT:!DES"
            else:
                # This is really only necessary for pre-1.0 versions
                # of openssl, but python 2.6 doesn't expose version
                # information.
                ssl_options["ssl_version"] = ssl.PROTOCOL_TLSv1
            return ssl_options
        return None

    def _on_timeout(self, info=None):
        """Timeout callback of _HTTPConnection instance.

        Raise a timeout HTTPError when a timeout occurs.

        :arg string info: More detailed timeout information.
        """
        self._timeout = None
        error_message = "Timeout {0}".format(info) if info else "Timeout"
        if self.final_callback is not None:
            raise HTTPError(599, error_message)

    def _remove_timeout(self):
        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None

    def _on_connect(self, stream):
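        # The connection is established: from here the request is validated
        # (method, unsupported proxy/interface options), default headers such
        # as Connection and Host are filled in, basic auth and body sanity
        # checks are applied, and finally the request line, headers, and body
        # are written to the HTTP1Connection.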
        if self.final_callback is None:
            # final_callback is cleared if we've hit our timeout.
            stream.close()
            return
        self.stream = stream
        self.stream.set_close_callback(self.on_connection_close)
        self._remove_timeout()
        if self.final_callback is None:
            return
        if self.request.request_timeout:
            self._timeout = self.io_loop.add_timeout(
                self.start_time + self.request.request_timeout,
                stack_context.wrap(functools.partial(self._on_timeout, "during request")))
        if (self.request.method not in self._SUPPORTED_METHODS and
                not self.request.allow_nonstandard_methods):
            raise KeyError("unknown method %s" % self.request.method)
        for key in ('network_interface',
                    'proxy_host', 'proxy_port',
                    'proxy_username', 'proxy_password',
                    'proxy_auth_mode'):
            if getattr(self.request, key, None):
                raise NotImplementedError('%s not supported' % key)
        if "Connection" not in self.request.headers:
            self.request.headers["Connection"] = "close"
        if "Host" not in self.request.headers:
            if '@' in self.parsed.netloc:
                self.request.headers["Host"] = self.parsed.netloc.rpartition('@')[-1]
            else:
                self.request.headers["Host"] = self.parsed.netloc
        username, password = None, None
        if self.parsed.username is not None:
            username, password = self.parsed.username, self.parsed.password
        elif self.request.auth_username is not None:
            username = self.request.auth_username
            password = self.request.auth_password or ''
        if username is not None:
            if self.request.auth_mode not in (None, "basic"):
                raise ValueError("unsupported auth_mode %s" %
                                 self.request.auth_mode)
            auth = utf8(username) + b":" + utf8(password)
            self.request.headers["Authorization"] = (b"Basic " +
                                                     base64.b64encode(auth))
        if self.request.user_agent:
            self.request.headers["User-Agent"] = self.request.user_agent
        if not self.request.allow_nonstandard_methods:
            # Some HTTP methods nearly always have bodies while others
            # almost never do. Fail in this case unless the user has
            # opted out of sanity checks with allow_nonstandard_methods.
            body_expected = self.request.method in ("POST", "PATCH", "PUT")
            body_present = (self.request.body is not None or
                            self.request.body_producer is not None)
            if ((body_expected and not body_present) or
                    (body_present and not body_expected)):
                raise ValueError(
                    'Body must %sbe None for method %s (unless '
                    'allow_nonstandard_methods is true)' %
                    ('not ' if body_expected else '', self.request.method))
        if self.request.expect_100_continue:
            self.request.headers["Expect"] = "100-continue"
        if self.request.body is not None:
            # When body_producer is used the caller is responsible for
            # setting Content-Length (or else chunked encoding will be used).
            self.request.headers["Content-Length"] = str(len(
                self.request.body))
        if (self.request.method == "POST" and
                "Content-Type" not in self.request.headers):
            self.request.headers["Content-Type"] = "application/x-www-form-urlencoded"
        if self.request.decompress_response:
            self.request.headers["Accept-Encoding"] = "gzip"
        req_path = ((self.parsed.path or '/') +
                    (('?' + self.parsed.query) if self.parsed.query else ''))
        self.connection = self._create_connection(stream)
        start_line = httputil.RequestStartLine(self.request.method,
                                               req_path, '')
        self.connection.write_headers(start_line, self.request.headers)
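        # With Expect: 100-continue, the body is held back until the server's
        # interim 100 response arrives; headers_received() then calls
        # _write_body(False) to send it. Otherwise the body is written
        # immediately and the response read is started afterwards.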
        if self.request.expect_100_continue:
            self._read_response()
        else:
            self._write_body(True)

    def _create_connection(self, stream):
        stream.set_nodelay(True)
        connection = HTTP1Connection(
            stream, True,
            HTTP1ConnectionParameters(
                no_keep_alive=True,
                max_header_size=self.max_header_size,
                max_body_size=self.max_body_size,
                decompress=self.request.decompress_response),
            self._sockaddr)
        return connection

    def _write_body(self, start_read):
        if self.request.body is not None:
            self.connection.write(self.request.body)
        elif self.request.body_producer is not None:
            fut = self.request.body_producer(self.connection.write)
            if fut is not None:
                fut = gen.convert_yielded(fut)

                def on_body_written(fut):
                    fut.result()
                    self.connection.finish()
                    if start_read:
                        self._read_response()
                self.io_loop.add_future(fut, on_body_written)
                return
        self.connection.finish()
        if start_read:
            self._read_response()

    def _read_response(self):
        # Ensure that any exception raised in read_response ends up in our
        # stack context.
        self.io_loop.add_future(
            self.connection.read_response(self),
            lambda f: f.result())

    def _release(self):
        if self.release_callback is not None:
            release_callback = self.release_callback
            self.release_callback = None
            release_callback()

    def _run_callback(self, response):
        self._release()
        if self.final_callback is not None:
            final_callback = self.final_callback
            self.final_callback = None
            self.io_loop.add_callback(final_callback, response)

    def _handle_exception(self, typ, value, tb):
        if self.final_callback:
            self._remove_timeout()
            if isinstance(value, StreamClosedError):
                if value.real_error is None:
                    value = HTTPError(599, "Stream closed")
                else:
                    value = value.real_error
            self._run_callback(HTTPResponse(self.request, 599, error=value,
                                            request_time=self.io_loop.time() - self.start_time,
                                            ))

            if hasattr(self, "stream"):
                # TODO: this may cause a StreamClosedError to be raised
                # by the connection's Future.  Should we cancel the
                # connection more gracefully?
                self.stream.close()
            return True
        else:
            # If our callback has already been called, we are probably
            # catching an exception that is not caused by us but rather
            # some child of our callback. Rather than drop it on the floor,
            # pass it along, unless it's just the stream being closed.
            return isinstance(value, StreamClosedError)

    def on_connection_close(self):
        if self.final_callback is not None:
            message = "Connection closed"
            if self.stream.error:
                raise self.stream.error
            try:
                raise HTTPError(599, message)
            except HTTPError:
                self._handle_exception(*sys.exc_info())

    def headers_received(self, first_line, headers):
        if self.request.expect_100_continue and first_line.code == 100:
            self._write_body(False)
            return
        self.code = first_line.code
        self.reason = first_line.reason
        self.headers = headers

        if self._should_follow_redirect():
            return

        if self.request.header_callback is not None:
            # Reassemble the start line.
            self.request.header_callback('%s %s %s\r\n' % first_line)
            for k, v in self.headers.get_all():
                self.request.header_callback("%s: %s\r\n" % (k, v))
            self.request.header_callback('\r\n')

    def _should_follow_redirect(self):
        return (self.request.follow_redirects and
                self.request.max_redirects > 0 and
                self.code in (301, 302, 303, 307, 308))

    def finish(self):
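        # The response is complete: either re-issue the request to follow a
        # redirect (decrementing max_redirects) or build the final
        # HTTPResponse from the buffered chunks and hand it to the callback.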
        data = b''.join(self.chunks)
        self._remove_timeout()
        original_request = getattr(self.request, "original_request",
                                   self.request)
        if self._should_follow_redirect():
            assert isinstance(self.request, _RequestProxy)
            new_request = copy.copy(self.request.request)
            new_request.url = urlparse.urljoin(self.request.url,
                                               self.headers["Location"])
            new_request.max_redirects = self.request.max_redirects - 1
            del new_request.headers["Host"]
            # http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.4
            # Client SHOULD make a GET request after a 303.
            # According to the spec, 302 should be followed by the same
            # method as the original request, but in practice browsers
            # treat 302 the same as 303, and many servers use 302 for
            # compatibility with pre-HTTP/1.1 user agents which don't
            # understand the 303 status.
            if self.code in (302, 303):
                new_request.method = "GET"
                new_request.body = None
                for h in ["Content-Length", "Content-Type",
                          "Content-Encoding", "Transfer-Encoding"]:
                    try:
                        del self.request.headers[h]
                    except KeyError:
                        pass
            new_request.original_request = original_request
            final_callback = self.final_callback
            self.final_callback = None
            self._release()
            self.client.fetch(new_request, final_callback)
            self._on_end_request()
            return
        if self.request.streaming_callback:
            buffer = BytesIO()
        else:
            buffer = BytesIO(data)  # TODO: don't require one big string?
        response = HTTPResponse(original_request,
                                self.code, reason=getattr(self, 'reason', None),
                                headers=self.headers,
                                request_time=self.io_loop.time() - self.start_time,
                                buffer=buffer,
                                effective_url=self.request.url)
        self._run_callback(response)
        self._on_end_request()

    def _on_end_request(self):
        self.stream.close()

    def data_received(self, chunk):
        if self._should_follow_redirect():
            # We're going to follow a redirect so just discard the body.
            return
        if self.request.streaming_callback is not None:
            self.request.streaming_callback(chunk)
        else:
            self.chunks.append(chunk)


if __name__ == "__main__":
    AsyncHTTPClient.configure(SimpleAsyncHTTPClient)
    main()