#!/usr/bin/env python
#
# Copyright 2014 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

"""A non-blocking TCP connection factory.
"""
from __future__ import absolute_import, division, print_function

import functools
import socket

from tornado.concurrent import Future
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream
from tornado import gen
from tornado.netutil import Resolver
from tornado.platform.auto import set_close_exec

# Seconds to wait on the first (primary-family) connection attempt before
# starting a parallel attempt against the secondary address family.
_INITIAL_CONNECT_TIMEOUT = 0.3


class _Connector(object):
    """An implementation of the "Happy Eyeballs" algorithm.

    "Happy Eyeballs" is documented in RFC6555 as the recommended practice
    for when both IPv4 and IPv6 addresses are available.

    In this implementation, we partition the addresses by family, and
    make the first connection attempt to whichever address was
    returned first by ``getaddrinfo``.  If that connection fails or
    times out, we begin a connection in parallel to the first address
    of the other family.  If there are additional failures we retry
    with other addresses, keeping one connection attempt per family
    in flight at a time.

    The result is delivered through ``self.future``, which resolves to an
    ``(af, addr, stream)`` tuple on the first success, or to the last
    connection error once every address has been tried.

    http://tools.ietf.org/html/rfc6555

    """
    def __init__(self, addrinfo, io_loop, connect):
        """
        :arg addrinfo: list of ``(af, addr)`` pairs to try, in resolver order.
        :arg io_loop: the `.IOLoop` used to schedule the fallback timeout.
        :arg connect: callable ``connect(af, addr)`` returning a `.Future`
            that resolves to a connected stream.
        """
        self.io_loop = io_loop
        self.connect = connect

        self.future = Future()
        # Handle of the pending fallback timeout, or None when not armed.
        self.timeout = None
        # Most recent connection error; raised if every address fails.
        self.last_error = None
        # Number of connection attempts that have not yet completed.
        self.remaining = len(addrinfo)
        self.primary_addrs, self.secondary_addrs = self.split(addrinfo)

    @staticmethod
    def split(addrinfo):
        """Partition the ``addrinfo`` list by address family.

        Returns two lists.  The first list contains the first entry from
        ``addrinfo`` and all others with the same family, and the
        second list contains all other addresses (normally one list will
        be AF_INET and the other AF_INET6, although non-standard resolvers
        may return additional families).  An empty ``addrinfo`` yields two
        empty lists.
        """
        primary = []
        secondary = []
        if not addrinfo:
            # Nothing to partition; the caller reports "connection failed".
            return primary, secondary
        primary_af = addrinfo[0][0]
        for af, addr in addrinfo:
            if af == primary_af:
                primary.append((af, addr))
            else:
                secondary.append((af, addr))
        return primary, secondary

    def start(self, timeout=_INITIAL_CONNECT_TIMEOUT):
        """Begin connecting and return the future holding the result.

        The first primary-family address is attempted immediately; if no
        attempt has succeeded within ``timeout`` seconds, the secondary
        family is tried in parallel.
        """
        self.try_connect(iter(self.primary_addrs))
        if not self.future.done():
            # Only arm the fallback timer if there is still work to race.
            self.set_timeout(timeout)
        return self.future

    def try_connect(self, addrs):
        """Start a connection attempt to the next address in ``addrs``.

        When ``addrs`` is exhausted and no attempts remain outstanding in
        either queue, fail ``self.future`` with the last error seen.

        NOTE(review): an exception raised synchronously by ``self.connect``
        (e.g. a failed ``bind`` in `TCPClient._create_stream`) propagates out
        of this method.  When this runs from a timeout or done-callback
        rather than from `start`, that exception leaves ``self.future``
        unresolved -- confirm whether this should be converted into a failed
        attempt.
        """
        try:
            af, addr = next(addrs)
        except StopIteration:
            # We've reached the end of our queue, but the other queue
            # might still be working.  Send a final error on the future
            # only when both queues are finished.
            if self.remaining == 0 and not self.future.done():
                self.future.set_exception(self.last_error or
                                          IOError("connection failed"))
            return
        future = self.connect(af, addr)
        future.add_done_callback(functools.partial(self.on_connect_done,
                                                   addrs, af, addr))

    def on_connect_done(self, addrs, af, addr, future):
        """Handle completion of one connection attempt.

        On failure, move on to the next address in the same queue (and, if
        the fallback timer is still pending, start the secondary queue
        immediately).  On success, resolve ``self.future`` -- or close the
        stream if another attempt already won.
        """
        self.remaining -= 1
        try:
            stream = future.result()
        except Exception as e:
            if self.future.done():
                return
            # Error: try again (but remember what happened so we have an
            # error to raise in the end).
            self.last_error = e
            self.try_connect(addrs)
            if self.timeout is not None:
                # If the first attempt failed, don't wait for the
                # timeout to try an address from the secondary queue.
                self.clear_timeout()
                self.on_timeout()
            return
        self.clear_timeout()
        if self.future.done():
            # This is a late arrival; just drop it.
            stream.close()
        else:
            self.future.set_result((af, addr, stream))

    def set_timeout(self, timeout):
        """Arm the fallback timer to fire ``timeout`` seconds from now."""
        self.timeout = self.io_loop.add_timeout(self.io_loop.time() + timeout,
                                                self.on_timeout)

    # Backwards-compatible alias for the original (misspelled) method name.
    set_timout = set_timeout

    def on_timeout(self):
        """Fallback timer fired: start attempting the secondary family."""
        self.timeout = None
        self.try_connect(iter(self.secondary_addrs))

    def clear_timeout(self):
        """Cancel the pending fallback timer, if any, and forget its handle."""
        if self.timeout is not None:
            self.io_loop.remove_timeout(self.timeout)
            self.timeout = None


class TCPClient(object):
    """A non-blocking TCP connection factory.

    .. versionchanged:: 4.1
       The ``io_loop`` argument is deprecated.
    """
    def __init__(self, resolver=None, io_loop=None):
        """
        :arg resolver: optional `.Resolver`; when omitted, a default one is
            created and owned (and later closed) by this client.
        :arg io_loop: deprecated; defaults to ``IOLoop.current()``.
        """
        self.io_loop = io_loop or IOLoop.current()
        if resolver is not None:
            self.resolver = resolver
            self._own_resolver = False
        else:
            self.resolver = Resolver(io_loop=io_loop)
            self._own_resolver = True

    def close(self):
        """Close the resolver, but only if this client created it."""
        if self._own_resolver:
            self.resolver.close()

    @gen.coroutine
    def connect(self, host, port, af=socket.AF_UNSPEC, ssl_options=None,
                max_buffer_size=None, source_ip=None, source_port=None):
        """Connect to the given host and port.

        Asynchronously returns an `.IOStream` (or `.SSLIOStream` if
        ``ssl_options`` is not None).

        Using the ``source_ip`` kwarg, one can specify the source
        IP address to use when establishing the connection.
        In case the user needs to resolve and
        use a specific interface, it has to be handled outside
        of Tornado as this depends very much on the platform.

        Similarly, when the user requires a certain source port, it can
        be specified using the ``source_port`` arg.

        .. versionchanged:: 4.5
           Added the ``source_ip`` and ``source_port`` arguments.
        """
        addrinfo = yield self.resolver.resolve(host, port, af)
        connector = _Connector(
            addrinfo, self.io_loop,
            functools.partial(self._create_stream, max_buffer_size,
                              source_ip=source_ip, source_port=source_port)
        )
        af, addr, stream = yield connector.start()
        # TODO: For better performance we could cache the (af, addr)
        # information here and re-use it on subsequent connections to
        # the same host. (http://tools.ietf.org/html/rfc6555#section-4.2)
        if ssl_options is not None:
            stream = yield stream.start_tls(False, ssl_options=ssl_options,
                                            server_hostname=host)
        raise gen.Return(stream)

    def _create_stream(self, max_buffer_size, af, addr, source_ip=None,
                       source_port=None):
        """Create a plaintext stream of family ``af`` and connect it to ``addr``.

        Returns the `.Future` from ``IOStream.connect``, or an already-failed
        `.Future` if the stream could not be constructed.  A failure to bind
        the requested source IP/port is raised synchronously.
        """
        # Always connect in plaintext; we'll convert to ssl if necessary
        # after one connection has completed.
        source_port_bind = source_port if isinstance(source_port, int) else 0
        source_ip_bind = source_ip
        if source_port_bind and not source_ip:
            # User required a specific port, but did not specify
            # a certain source IP, will bind to the default loopback,
            # using the same address family as the requested af socket:
            # - 127.0.0.1 for IPv4
            # - ::1 for IPv6
            # NOTE(review): a socket bound to a loopback address can
            # generally only reach local destinations; confirm this
            # default is intended rather than the wildcard address.
            source_ip_bind = '::1' if af == socket.AF_INET6 else '127.0.0.1'
        socket_obj = socket.socket(af)
        set_close_exec(socket_obj.fileno())
        if source_port_bind or source_ip_bind:
            # If the user requires binding also to a specific IP/port.
            try:
                socket_obj.bind((source_ip_bind, source_port_bind))
            except socket.error:
                socket_obj.close()
                # Fail loudly if unable to use the IP/port.
                raise
        try:
            stream = IOStream(socket_obj,
                              io_loop=self.io_loop,
                              max_buffer_size=max_buffer_size)
        except socket.error as e:
            # The stream never took ownership of the socket; close it here
            # so the file descriptor is not leaked.
            socket_obj.close()
            fu = Future()
            fu.set_exception(e)
            return fu
        return stream.connect(addr)