1#!/usr/bin/env python
2#
3# Copyright 2014 Facebook
4#
5# Licensed under the Apache License, Version 2.0 (the "License"); you may
6# not use this file except in compliance with the License. You may obtain
7# a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14# License for the specific language governing permissions and limitations
15# under the License.
16
17"""A non-blocking TCP connection factory.
18"""
19from __future__ import absolute_import, division, print_function
20
21import functools
22import socket
23
24from tornado.concurrent import Future
25from tornado.ioloop import IOLoop
26from tornado.iostream import IOStream
27from tornado import gen
28from tornado.netutil import Resolver
29from tornado.platform.auto import set_close_exec
30
# Default fallback delay, in seconds (added to IOLoop.time()), used by
# _Connector.start: how long the first connection attempt on the primary
# address family may run before an attempt on the secondary family is
# started in parallel.
_INITIAL_CONNECT_TIMEOUT = 0.3
32
33
class _Connector(object):
    """A stateless implementation of the "Happy Eyeballs" algorithm.

    "Happy Eyeballs" is documented in RFC6555 as the recommended practice
    for when both IPv4 and IPv6 addresses are available.

    In this implementation, we partition the addresses by family, and
    make the first connection attempt to whichever address was
    returned first by ``getaddrinfo``.  If that connection fails or
    times out, we begin a connection in parallel to the first address
    of the other family.  If there are additional failures we retry
    with other addresses, keeping one connection attempt per family
    in flight at a time.

    "Stateless" means nothing is remembered across connections (e.g.
    which family worked for a host); each instance drives exactly one
    connection via `start`.

    http://tools.ietf.org/html/rfc6555

    """
    def __init__(self, addrinfo, io_loop, connect):
        """Prepare (but do not start) a connection attempt.

        :arg addrinfo: list of ``(af, addr)`` pairs in resolver order.
        :arg io_loop: IOLoop used to schedule the fallback timeout.
        :arg connect: callable ``connect(af, addr)`` returning a Future
            that resolves to a connected stream.
        """
        self.io_loop = io_loop
        self.connect = connect

        self.future = Future()
        # Handle of the armed fallback timeout, or None when none is pending.
        self.timeout = None
        # Most recent attempt error; reported if every attempt fails.
        self.last_error = None
        # Attempts not yet completed; the final error is only reported
        # once this reaches zero (i.e. both families are exhausted).
        self.remaining = len(addrinfo)
        self.primary_addrs, self.secondary_addrs = self.split(addrinfo)

    @staticmethod
    def split(addrinfo):
        """Partition the ``addrinfo`` list by address family.

        Returns two lists.  The first list contains the first entry from
        ``addrinfo`` and all others with the same family, and the
        second list contains all other addresses (normally one list will
        be AF_INET and the other AF_INET6, although non-standard resolvers
        may return additional families).  An empty ``addrinfo`` yields
        two empty lists.
        """
        primary = []
        secondary = []
        if not addrinfo:
            # Nothing to partition; start() then fails the future with a
            # connection error instead of an IndexError escaping here.
            return primary, secondary
        primary_af = addrinfo[0][0]
        for af, addr in addrinfo:
            if af == primary_af:
                primary.append((af, addr))
            else:
                secondary.append((af, addr))
        return primary, secondary

    def start(self, timeout=_INITIAL_CONNECT_TIMEOUT):
        """Begin connecting; return a Future for ``(af, addr, stream)``.

        :arg timeout: delay in seconds before the secondary family is
            also tried while the primary attempt is still pending.
        """
        # Arm the fallback timer *before* the first attempt: if that
        # attempt fails synchronously, on_connect_done sees the pending
        # timer and moves to the secondary family right away.
        self.set_timeout(timeout)
        self.try_connect(iter(self.primary_addrs))
        return self.future

    def try_connect(self, addrs):
        """Start an attempt on the next address from iterator ``addrs``.

        When ``addrs`` is exhausted and no attempts remain, fail
        ``self.future`` with the last recorded error.
        """
        try:
            af, addr = next(addrs)
        except StopIteration:
            # We've reached the end of our queue, but the other queue
            # might still be working.  Send a final error on the future
            # only when both queues are finished.
            if self.remaining == 0 and not self.future.done():
                self.clear_timeout()
                self.future.set_exception(self.last_error or
                                          IOError("connection failed"))
            return
        try:
            future = self.connect(af, addr)
        except Exception as e:
            # connect() may raise synchronously (e.g. socket creation or
            # bind failure).  Route it through the normal failure path so
            # the next address is tried; letting it escape from an IOLoop
            # or done-callback would leave self.future pending forever.
            future = Future()
            future.set_exception(e)
        future.add_done_callback(functools.partial(self.on_connect_done,
                                                   addrs, af, addr))

    def on_connect_done(self, addrs, af, addr, future):
        """Handle completion of the attempt on ``(af, addr)``."""
        self.remaining -= 1
        try:
            stream = future.result()
        except Exception as e:
            if self.future.done():
                return
            # Error: try again (but remember what happened so we have an
            # error to raise in the end)
            self.last_error = e
            self.try_connect(addrs)
            if self.timeout is not None:
                # If the first attempt failed, don't wait for the
                # timeout to try an address from the secondary queue.
                self.clear_timeout()
                self.on_timeout()
            return
        self.clear_timeout()
        if self.future.done():
            # This is a late arrival; just drop it.
            stream.close()
        else:
            self.future.set_result((af, addr, stream))

    def set_timeout(self, timeout):
        """Schedule `on_timeout` to run ``timeout`` seconds from now."""
        self.timeout = self.io_loop.add_timeout(self.io_loop.time() + timeout,
                                                self.on_timeout)

    # Backward-compatible alias for the original (misspelled) name.
    set_timout = set_timeout

    def on_timeout(self):
        """Start trying the secondary address family."""
        self.timeout = None
        if self.future.done():
            # Already resolved; don't open connections nobody will use.
            return
        self.try_connect(iter(self.secondary_addrs))

    def clear_timeout(self):
        """Cancel the pending fallback timeout, if any."""
        if self.timeout is not None:
            self.io_loop.remove_timeout(self.timeout)
            self.timeout = None
136
137
class TCPClient(object):
    """A non-blocking TCP connection factory.

    .. versionchanged:: 4.1
       The ``io_loop`` argument is deprecated.
    """
    def __init__(self, resolver=None, io_loop=None):
        self.io_loop = io_loop or IOLoop.current()
        if resolver is not None:
            self.resolver = resolver
            self._own_resolver = False
        else:
            self.resolver = Resolver(io_loop=io_loop)
            self._own_resolver = True

    def close(self):
        """Close the resolver, but only if this client created it."""
        if self._own_resolver:
            self.resolver.close()

    @gen.coroutine
    def connect(self, host, port, af=socket.AF_UNSPEC, ssl_options=None,
                max_buffer_size=None, source_ip=None, source_port=None):
        """Connect to the given host and port.

        Asynchronously returns an `.IOStream` (or `.SSLIOStream` if
        ``ssl_options`` is not None).

        Using the ``source_ip`` kwarg, one can specify the source
        IP address to use when establishing the connection.
        In case the user needs to resolve and
        use a specific interface, it has to be handled outside
        of Tornado as this depends very much on the platform.

        Similarly, when the user requires a certain source port, it can
        be specified using the ``source_port`` arg.  If ``source_port`` is
        given without ``source_ip``, the socket is bound to the loopback
        address of the matching family (``127.0.0.1`` or ``::1``).

        .. versionchanged:: 4.5
           Added the ``source_ip`` and ``source_port`` arguments.
        """
        addrinfo = yield self.resolver.resolve(host, port, af)
        connector = _Connector(
            addrinfo, self.io_loop,
            functools.partial(self._create_stream, max_buffer_size,
                              source_ip=source_ip, source_port=source_port)
        )
        af, addr, stream = yield connector.start()
        # TODO: For better performance we could cache the (af, addr)
        # information here and re-use it on subsequent connections to
        # the same host. (http://tools.ietf.org/html/rfc6555#section-4.2)
        if ssl_options is not None:
            stream = yield stream.start_tls(False, ssl_options=ssl_options,
                                            server_hostname=host)
        raise gen.Return(stream)

    def _create_stream(self, max_buffer_size, af, addr, source_ip=None,
                       source_port=None):
        """Create a socket of family ``af`` and start connecting to ``addr``.

        Returns the Future from ``IOStream.connect``.  If creating,
        binding, or wrapping the socket fails with ``socket.error``, the
        socket (if any) is closed and a Future already failed with that
        error is returned, so the connector can move on to the next
        address instead of aborting or stalling.
        """
        # Always connect in plaintext; we'll convert to ssl if necessary
        # after one connection has completed.
        source_port_bind = source_port if isinstance(source_port, int) else 0
        source_ip_bind = source_ip
        if source_port_bind and not source_ip:
            # User required a specific port, but did not specify
            # a certain source IP, will bind to the default loopback.
            # Trying to use the same address family as the requested af socket:
            # - 127.0.0.1 for IPv4
            # - ::1 for IPv6
            source_ip_bind = '::1' if af == socket.AF_INET6 else '127.0.0.1'
        socket_obj = None
        try:
            # May fail e.g. with EAFNOSUPPORT when this family is unavailable.
            socket_obj = socket.socket(af)
            set_close_exec(socket_obj.fileno())
            if source_port_bind or source_ip_bind:
                # If the user requires binding also to a specific IP/port.
                socket_obj.bind((source_ip_bind, source_port_bind))
            stream = IOStream(socket_obj,
                              io_loop=self.io_loop,
                              max_buffer_size=max_buffer_size)
        except socket.error as e:
            # Don't leak the descriptor; report the error through the
            # returned Future so it still surfaces to the caller if no
            # other address succeeds.
            if socket_obj is not None:
                socket_obj.close()
            fu = Future()
            fu.set_exception(e)
            return fu
        return stream.connect(addr)
225