1# -*- coding: utf-8 -*-
2
3from __future__ import absolute_import, division
4
5import ssl
6import asyncio
7import errno
8import os
9import socket
10import struct
11import sys
12
13from thriftpy2.transport import TTransportException
14from thriftpy2.transport._ssl import (
15    create_thriftpy_context,
16    RESTRICTED_SERVER_CIPHERS,
17    DEFAULT_CIPHERS
18)
19
20
class TAsyncSocket(object):
    """Asyncio stream socket implementation for the client side."""

    def __init__(self, host=None, port=None, unix_socket=None,
                 sock=None, socket_family=socket.AF_INET,
                 socket_timeout=3000, connect_timeout=None,
                 ssl_context=None, validate=True,
                 cafile=None, capath=None, certfile=None, keyfile=None,
                 ciphers=DEFAULT_CIPHERS):
        """Initialize a TAsyncSocket

        TAsyncSocket can be initialized in 3 ways:
        * host + port. can configure to use AF_INET/AF_INET6
        * unix_socket
        * socket. should pass an already connected socket here; open()
          then only wraps it in asyncio streams instead of dialing.

        @param host(str)    The host to connect to. Also used as the TLS
            server_hostname whenever TLS is enabled.
        @param port(int)    The (TCP) port to connect to.
        @param unix_socket(str) The filename of a unix socket to connect to.
        @param sock(socket)     Initialize with an opened socket directly.
            If this param is used, the port and unix_socket params are
            ignored (host is still used as the TLS server_hostname).
        @param socket_family(int) socket.AF_INET or socket.AF_INET6. only
            take effect when using host/port
        @param socket_timeout   socket timeout in ms. Applied to stream
            setup (including the TLS handshake), reads and flushes. A falsy
            value means no timeout.
        @param connect_timeout  connect timeout in ms, only used in
            connection, will be set to socket_timeout if not set.
        @param validate(bool)       Set to False to disable SSL certificate
            validation and hostname validation. Default enabled.
        @param cafile(str)          Path to a file of concatenated CA
            certificates in PEM format.
        @param capath(str)          Path to a directory containing several CA
            certificates in PEM format, following an OpenSSL specific layout.
        @param certfile(str)        The certfile string must be the path to a
            single file in PEM format containing the certificate as well as
            any number of CA certificates needed to establish the
            certificate's authenticity.
        @param keyfile(str)         The keyfile string, if not present,
            the private key will be taken from certfile as well.
        @param ciphers(list<str>)   The cipher suites to allow
        @param ssl_context(SSLContext)  Customize the SSLContext, can be used
            to persist SSLContext object. Caution it's easy to get wrong, only
            use if you know what you're doing.

        TLS is enabled when ssl_context is given, or when any of certfile,
        keyfile, cafile or capath is given.
        """
        if sock:
            # Pre-connected socket: there is no address to dial, open() only
            # wraps it in asyncio streams.
            self.raw_sock = sock
            self.unix_socket = None
            self.host = None
            self.port = None
            self.sock_factory = asyncio.open_connection
        elif unix_socket:
            self.unix_socket = unix_socket
            self.host = None
            self.port = None
            self.raw_sock = None
            self.sock_factory = asyncio.open_unix_connection
        else:
            self.unix_socket = None
            self.host = host
            self.port = port
            self.raw_sock = None
            self.sock_factory = asyncio.open_connection

        self.socket_family = socket_family
        self.socket_timeout = socket_timeout / 1000 if socket_timeout else None
        self.connect_timeout = connect_timeout / 1000 if connect_timeout \
            else self.socket_timeout

        # Stream pair created by open().
        self.reader = None
        self.writer = None

        if ssl_context:
            self.ssl_context = ssl_context
            self.server_hostname = host
        elif certfile or keyfile or cafile or capath:
            # Supplying CA material alone must also enable TLS; otherwise a
            # caller asking for server verification silently gets plaintext.
            self.server_hostname = host
            self.ssl_context = create_thriftpy_context(server_side=False,
                                                       ciphers=ciphers)

            if cafile or capath:
                self.ssl_context.load_verify_locations(cafile=cafile,
                                                       capath=capath)

            if certfile:
                self.ssl_context.load_cert_chain(certfile, keyfile=keyfile)

            if not validate:
                self.ssl_context.check_hostname = False
                self.ssl_context.verify_mode = ssl.CERT_NONE
        else:
            self.ssl_context = None
            self.server_hostname = None

    def _init_sock(self):
        """Create a fresh, unconnected stream socket in ``self.raw_sock``."""
        if self.unix_socket:
            _sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        else:
            _sock = socket.socket(self.socket_family, socket.SOCK_STREAM)
            _sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

        # socket options
        linger = struct.pack('ii', 0, 0)
        _sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, linger)
        _sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

        self.raw_sock = _sock

    def set_handle(self, sock):
        self.raw_sock = sock

    def set_timeout(self, ms):
        """Backward compat api, will bind the timeout to both connect_timeout
        and socket_timeout.

        The raw socket is deliberately left untouched. Once asyncio owns it,
        it must stay non-blocking; calling ``settimeout`` on it could make it
        blocking and stall the event loop. Timeouts are enforced with
        ``asyncio.wait_for`` using the stored values instead.
        """
        self.socket_timeout = ms / 1000 if (ms and ms > 0) else None
        self.connect_timeout = self.socket_timeout

    def is_open(self):
        return bool(self.raw_sock)

    async def _connect(self, addr):
        """Connect ``self.raw_sock`` to ``addr`` via the running event loop."""
        loop = asyncio.get_event_loop()
        self.raw_sock.setblocking(False)
        if not self.unix_socket:
            # Resolve with the configured family (as socket.connect did),
            # but through the loop so DNS does not block other coroutines.
            infos = await loop.getaddrinfo(
                self.host, self.port,
                family=self.socket_family, type=socket.SOCK_STREAM)
            addr = infos[0][4]
        await loop.sock_connect(self.raw_sock, addr)

    async def open(self):
        """Connect and build the asyncio reader/writer stream pair.

        A socket supplied via ``sock`` (with no host/unix_socket configured)
        is wrapped as-is. Otherwise a fresh socket is created and connected
        to the configured address without blocking the event loop.

        @raises TTransportException  NOT_OPEN if connecting, the TLS handshake
            or stream setup fails or times out. In that case the socket is
            closed and ``raw_sock`` is reset to None.
        """
        preconnected = (self.raw_sock is not None and
                        self.unix_socket is None and self.host is None)
        addr = self.unix_socket or (self.host, self.port)
        target = str(self.raw_sock) if preconnected else str(addr)

        if not preconnected:
            self._init_sock()

        try:
            if not preconnected:
                await asyncio.wait_for(self._connect(addr),
                                       self.connect_timeout)

            kwargs = {'sock': self.raw_sock, 'ssl': self.ssl_context}
            if self.server_hostname:
                kwargs['server_hostname'] = self.server_hostname

            self.reader, self.writer = await asyncio.wait_for(
                self.sock_factory(**kwargs),
                self.socket_timeout
            )

        # asyncio.TimeoutError is not an OSError before Python 3.11.
        except (OSError, asyncio.TimeoutError) as err:
            # Don't leak the half-opened socket, and keep is_open() honest.
            try:
                self.raw_sock.close()
            except OSError:
                pass
            self.raw_sock = None
            raise TTransportException(
                type=TTransportException.NOT_OPEN,
                message="Could not connect to %s" % target) from err

    async def read(self, sz):
        """Read up to ``sz`` bytes, bounded by ``socket_timeout``.

        @raises TTransportException  END_OF_FILE when the peer closed the
            connection (0 bytes read).
        """
        try:
            buff = await asyncio.wait_for(
                self.reader.read(sz),
                self.socket_timeout
            )
        except socket.error as e:
            if (e.errno == errno.ECONNRESET and
                    (sys.platform == 'darwin' or
                     sys.platform.startswith('freebsd'))):
                # freebsd and Mach don't follow POSIX semantic of recv
                # and fail with ECONNRESET if peer performed shutdown.
                # See corresponding comment and code in TSocket::read()
                # in lib/cpp/src/transport/TSocket.cpp.
                self.close()
                # Trigger the check to raise the END_OF_FILE exception below.
                buff = b''
            else:
                raise

        if not buff:
            raise TTransportException(type=TTransportException.END_OF_FILE,
                                      message='TSocket read 0 bytes')
        return buff

    def write(self, buff):
        """Buffer ``buff`` on the stream writer; call flush() to send it."""
        self.writer.write(buff)

    async def flush(self):
        """Wait until buffered data is drained, bounded by socket_timeout."""
        await asyncio.wait_for(self.writer.drain(), self.socket_timeout)

    def close(self):
        """Close the stream writer (if any) and the raw socket, best-effort.

        ``raw_sock`` is always reset to None, so ``is_open()`` reports False
        afterwards even if closing raised.
        """
        if not self.raw_sock:
            return

        try:
            if self.writer is not None:
                self.writer.close()
        except (socket.error, OSError):
            pass

        try:
            self.raw_sock.close()
        except (socket.error, OSError):
            pass
        self.raw_sock = None
209
210
class TAsyncServerSocket(object):
    """Asyncio socket implementation for the server side."""

    def __init__(self, host=None, port=None, unix_socket=None,
                 socket_family=socket.AF_INET, client_timeout=3000,
                 backlog=128, ssl_context=None, certfile=None, keyfile=None,
                 ciphers=RESTRICTED_SERVER_CIPHERS):
        """Initialize a TAsyncServerSocket

        TAsyncServerSocket can be initialized in 2 ways:
        * host + port. can configure to use AF_INET/AF_INET6
        * unix_socket

        @param host(str)    The host to bind to
        @param port(int)    The (TCP) port to bind to
        @param unix_socket(str) The filename of a unix socket to bind to
        @param socket_family(int) socket.AF_INET or socket.AF_INET6. only
            take effect when using host/port
        @param client_timeout   per-client handler timeout in ms; a falsy
            value means no timeout
        @param backlog          backlog for server socket
        @param certfile(str)        The server cert pem filename
        @param keyfile(str)         The server cert key filename
        @param ciphers(list<str>)   The cipher suites to allow
        @param ssl_context(SSLContext)  Customize the SSLContext, can be used
            to persist SSLContext object. Caution it's easy to get wrong, only
            use if you know what you're doing.
        @raises IOError  if certfile is given but not readable.
        """
        if unix_socket:
            self.unix_socket = unix_socket
            self.host = None
            self.port = None
            self.sock_factory = asyncio.start_unix_server
        else:
            self.unix_socket = None
            self.host = host
            self.port = port
            self.sock_factory = asyncio.start_server

        self.socket_family = socket_family
        self.client_timeout = client_timeout / 1000 if client_timeout else None
        self.backlog = backlog
        # Listening socket, created by listen(); close() is safe before that.
        self.raw_sock = None

        if ssl_context:
            self.ssl_context = ssl_context
        elif certfile:
            if not os.access(certfile, os.R_OK):
                raise IOError('No such certfile found: %s' % certfile)

            self.ssl_context = create_thriftpy_context(server_side=True,
                                                       ciphers=ciphers)
            self.ssl_context.load_cert_chain(certfile, keyfile=keyfile)
        else:
            self.ssl_context = None

    def _init_sock(self):
        """Create the (unbound) listening socket in ``self.raw_sock``."""
        if self.unix_socket:
            # Remove a stale socket file: if nothing accepts on it the probe
            # connect is refused. Use a throwaway probe socket so the socket
            # that gets bound is never a (possibly) connected one.
            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as probe:
                try:
                    probe.connect(self.unix_socket)
                except (socket.error, OSError) as err:
                    if err.errno == errno.ECONNREFUSED:
                        os.unlink(self.unix_socket)
            _sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        else:
            _sock = socket.socket(self.socket_family, socket.SOCK_STREAM)

        _sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        if hasattr(socket, "SO_REUSEPORT"):
            try:
                _sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
            except (socket.error, OSError) as err:
                # The constant may exist while the option is unsupported for
                # this platform or socket family; that is not fatal.
                if err.errno not in (errno.ENOPROTOOPT, errno.EINVAL,
                                     errno.EOPNOTSUPP):
                    raise
        _sock.settimeout(None)
        self.raw_sock = _sock

    def listen(self):
        """Create, bind and start listening on the server socket."""
        self._init_sock()

        addr = self.unix_socket or (self.host, self.port)
        self.raw_sock.bind(addr)
        self.raw_sock.listen(self.backlog)

    async def accept(self, callback):
        """Start serving on the listening socket.

        ``callback`` is called with a StreamHandler per client connection and
        awaited for at most ``client_timeout`` seconds. On timeout the client
        connection is closed and the TimeoutError is re-raised inside that
        connection's task.

        @return  the asyncio Server returned by start_server /
            start_unix_server.
        """
        async def _serve_client(reader, writer):
            handler = StreamHandler(reader, writer)
            try:
                await asyncio.wait_for(callback(handler), self.client_timeout)
            except asyncio.TimeoutError:
                # The cancelled handler can no longer close it; don't leak.
                handler.close()
                raise

        # Pass backlog explicitly: asyncio re-listens on the socket and would
        # otherwise reset it to its own default.
        server = await self.sock_factory(
            _serve_client,
            sock=self.raw_sock,
            ssl=self.ssl_context,
            backlog=self.backlog
        )
        return server

    def close(self):
        """Shut down and close the listening socket, best-effort."""
        if not self.raw_sock:
            return

        try:
            # May fail (e.g. ENOTCONN on a listening socket on some
            # platforms); the socket must still be closed below.
            self.raw_sock.shutdown(socket.SHUT_RDWR)
        except (socket.error, OSError):
            pass

        try:
            self.raw_sock.close()
        except (socket.error, OSError):
            pass
        self.raw_sock = None
317
318
class StreamHandler(object):
    """Transport over an asyncio (reader, writer) pair for one connection.

    TAsyncServerSocket.accept constructs one per accepted client and passes
    it to the user callback.
    """

    def __init__(self, reader, writer):
        self.reader, self.writer = reader, writer

    async def read(self, sz):
        """Read up to ``sz`` bytes from the reader.

        @raises TTransportException  END_OF_FILE when the peer closed the
            connection (0 bytes read).
        """
        try:
            buff = await self.reader.read(sz)
        except socket.error as e:
            if (e.errno == errno.ECONNRESET and
                    (sys.platform == 'darwin' or
                     sys.platform.startswith('freebsd'))):
                # freebsd and Mach don't follow POSIX semantic of recv
                # and fail with ECONNRESET if peer performed shutdown.
                # See corresponding comment and code in TSocket::read()
                # in lib/cpp/src/transport/TSocket.cpp.
                self.close()
                # Trigger the check to raise the END_OF_FILE exception below.
                buff = b''
            else:
                raise

        if not buff:
            raise TTransportException(type=TTransportException.END_OF_FILE,
                                      message='TSocket read 0 bytes')
        return buff

    def write(self, buff):
        """Buffer ``buff`` on the writer; call flush() to send it."""
        self.writer.write(buff)

    async def flush(self):
        """Wait until the writer's buffer has been drained."""
        await self.writer.drain()

    def close(self):
        """Close the writer, ignoring socket errors."""
        try:
            self.writer.close()
        except (socket.error, OSError):
            pass

    async def open(self):
        """No-op: the connection is already established by the server."""
        pass
362