1# -*- coding: utf-8 -*- 2 3from __future__ import absolute_import, division 4 5import ssl 6import asyncio 7import errno 8import os 9import socket 10import struct 11import sys 12 13from thriftpy2.transport import TTransportException 14from thriftpy2.transport._ssl import ( 15 create_thriftpy_context, 16 RESTRICTED_SERVER_CIPHERS, 17 DEFAULT_CIPHERS 18) 19 20 21class TAsyncSocket(object): 22 """Socket implementation for client side.""" 23 24 def __init__(self, host=None, port=None, unix_socket=None, 25 sock=None, socket_family=socket.AF_INET, 26 socket_timeout=3000, connect_timeout=None, 27 ssl_context=None, validate=True, 28 cafile=None, capath=None, certfile=None, keyfile=None, 29 ciphers=DEFAULT_CIPHERS): 30 """Initialize a TSocket 31 32 TSocket can be initialized in 3 ways: 33 * host + port. can configure to use AF_INET/AF_INET6 34 * unix_socket 35 * socket. should pass already opened socket here. 36 37 @param host(str) The host to connect to. 38 @param port(int) The (TCP) port to connect to. 39 @param unix_socket(str) The filename of a unix socket to connect to. 40 @param sock(socket) Initialize with opened socket directly. 41 If this param used, the host, port and unix_socket params will 42 be ignored. 43 @param socket_family(str) socket.AF_INET or socket.AF_INET6. only 44 take effect when using host/port 45 @param socket_timeout socket timeout in ms 46 @param connect_timeout connect timeout in ms, only used in 47 connection, will be set to socket_timeout if not set. 48 @param validate(bool) Set to False to disable SSL certificate 49 validation and hostname validation. Default enabled. 50 @param cafile(str) Path to a file of concatenated CA 51 certificates in PEM format. 52 @param capath(str) path to a directory containing several CA 53 certificates in PEM format, following an OpenSSL specific layout. 54 @param certfile(str) The certfile string must be the path to a 55 single file in PEM format containing the certificate as well as 56 any number of CA certificates needed to establish the 57 certificate’s authenticity. 58 @param keyfile(str) The keyfile string, if not present, 59 the private key will be taken from certfile as well. 60 @param ciphers(list<str>) The cipher suites to allow 61 @param ssl_context(SSLContext) Customize the SSLContext, can be used 62 to persist SSLContext object. Caution it's easy to get wrong, only 63 use if you know what you're doing. 64 """ 65 if sock: 66 self.raw_sock = sock 67 elif unix_socket: 68 self.unix_socket = unix_socket 69 self.host = None 70 self.port = None 71 self.raw_sock = None 72 self.sock_factory = asyncio.open_unix_connection 73 else: 74 self.unix_socket = None 75 self.host = host 76 self.port = port 77 self.raw_sock = None 78 self.sock_factory = asyncio.open_connection 79 80 self.socket_family = socket_family 81 self.socket_timeout = socket_timeout / 1000 if socket_timeout else None 82 self.connect_timeout = connect_timeout / 1000 if connect_timeout \ 83 else self.socket_timeout 84 85 if ssl_context: 86 self.ssl_context = ssl_context 87 self.server_hostname = host 88 elif certfile or keyfile: 89 self.server_hostname = host 90 self.ssl_context = create_thriftpy_context(server_side=False, 91 ciphers=ciphers) 92 93 if cafile or capath: 94 self.ssl_context.load_verify_locations(cafile=cafile, 95 capath=capath) 96 97 if certfile: 98 self.ssl_context.load_cert_chain(certfile, keyfile=keyfile) 99 100 if not validate: 101 self.ssl_context.check_hostname = False 102 self.ssl_context.verify_mode = ssl.CERT_NONE 103 else: 104 self.ssl_context = None 105 self.server_hostname = None 106 107 def _init_sock(self): 108 if self.unix_socket: 109 _sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) 110 else: 111 _sock = socket.socket(self.socket_family, socket.SOCK_STREAM) 112 _sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) 113 114 # socket options 115 linger = struct.pack('ii', 0, 0) 116 _sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, linger) 117 _sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) 118 119 self.raw_sock = _sock 120 121 def set_handle(self, sock): 122 self.raw_sock = sock 123 124 def set_timeout(self, ms): 125 """Backward compat api, will bind the timeout to both connect_timeout 126 and socket_timeout. 127 """ 128 self.socket_timeout = ms / 1000 if (ms and ms > 0) else None 129 self.connect_timeout = self.socket_timeout 130 131 if self.raw_sock is not None: 132 self.raw_sock.settimeout(self.socket_timeout) 133 134 def is_open(self): 135 return bool(self.raw_sock) 136 137 @asyncio.coroutine 138 def open(self): 139 self._init_sock() 140 141 addr = self.unix_socket or (self.host, self.port) 142 143 try: 144 if self.connect_timeout: 145 self.raw_sock.settimeout(self.connect_timeout) 146 147 self.raw_sock.connect(addr) 148 149 if self.socket_timeout: 150 self.raw_sock.settimeout(self.socket_timeout) 151 152 kwargs = {'sock': self.raw_sock, 'ssl': self.ssl_context} 153 if self.server_hostname: 154 kwargs['server_hostname'] = self.server_hostname 155 156 self.reader, self.writer = yield from asyncio.wait_for( 157 self.sock_factory(**kwargs), 158 self.socket_timeout 159 ) 160 161 except (socket.error, OSError): 162 raise TTransportException( 163 type=TTransportException.NOT_OPEN, 164 message="Could not connect to %s" % str(addr)) 165 166 @asyncio.coroutine 167 def read(self, sz): 168 try: 169 buff = yield from asyncio.wait_for( 170 self.reader.read(sz), 171 self.connect_timeout 172 ) 173 except socket.error as e: 174 if (e.args[0] == errno.ECONNRESET and 175 (sys.platform == 'darwin' or 176 sys.platform.startswith('freebsd'))): 177 # freebsd and Mach don't follow POSIX semantic of recv 178 # and fail with ECONNRESET if peer performed shutdown. 179 # See corresponding comment and code in TSocket::read() 180 # in lib/cpp/src/transport/TSocket.cpp. 181 self.close() 182 # Trigger the check to raise the END_OF_FILE exception below. 183 buff = '' 184 else: 185 raise 186 187 if len(buff) == 0: 188 raise TTransportException(type=TTransportException.END_OF_FILE, 189 message='TSocket read 0 bytes') 190 return buff 191 192 def write(self, buff): 193 self.writer.write(buff) 194 195 @asyncio.coroutine 196 def flush(self): 197 yield from asyncio.wait_for(self.writer.drain(), self.connect_timeout) 198 199 def close(self): 200 if not self.raw_sock: 201 return 202 203 try: 204 self.writer.close() 205 self.raw_sock.close() 206 self.raw_sock = None 207 except (socket.error, OSError): 208 pass 209 210 211class TAsyncServerSocket(object): 212 """Socket implementation for server side.""" 213 214 def __init__(self, host=None, port=None, unix_socket=None, 215 socket_family=socket.AF_INET, client_timeout=3000, 216 backlog=128, ssl_context=None, certfile=None, keyfile=None, 217 ciphers=RESTRICTED_SERVER_CIPHERS): 218 """Initialize a TServerSocket 219 220 TSocket can be initialized in 2 ways: 221 * host + port. can configure to use AF_INET/AF_INET6 222 * unix_socket 223 224 @param host(str) The host to connect to 225 @param port(int) The (TCP) port to connect to 226 @param unix_socket(str) The filename of a unix socket to connect to 227 @param socket_family(str) socket.AF_INET or socket.AF_INET6. only 228 take effect when using host/port 229 @param client_timeout client socket timeout 230 @param backlog backlog for server socket 231 @param certfile(str) The server cert pem filename 232 @param keyfile(str) The server cert key filename 233 @param ciphers(list<str>) The cipher suites to allow 234 @param ssl_context(SSLContext) Customize the SSLContext, can be used 235 to persist SSLContext object. Caution it's easy to get wrong, only 236 use if you know what you're doing. 237 """ 238 if unix_socket: 239 self.unix_socket = unix_socket 240 self.host = None 241 self.port = None 242 self.sock_factory = asyncio.start_unix_server 243 else: 244 self.unix_socket = None 245 self.host = host 246 self.port = port 247 self.sock_factory = asyncio.start_server 248 249 self.socket_family = socket_family 250 self.client_timeout = client_timeout / 1000 if client_timeout else None 251 self.backlog = backlog 252 253 if ssl_context: 254 self.ssl_context = ssl_context 255 elif certfile: 256 if not os.access(certfile, os.R_OK): 257 raise IOError('No such certfile found: %s' % certfile) 258 259 self.ssl_context = create_thriftpy_context(server_side=True, 260 ciphers=ciphers) 261 self.ssl_context.load_cert_chain(certfile, keyfile=keyfile) 262 else: 263 self.ssl_context = None 264 265 def _init_sock(self): 266 if self.unix_socket: 267 # try remove the sock file it already exists 268 _sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) 269 try: 270 _sock.connect(self.unix_socket) 271 except (socket.error, OSError) as err: 272 if err.args[0] == errno.ECONNREFUSED: 273 os.unlink(self.unix_socket) 274 else: 275 _sock = socket.socket(self.socket_family, socket.SOCK_STREAM) 276 277 _sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 278 if hasattr(socket, "SO_REUSEPORT"): 279 try: 280 _sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) 281 except socket.error as err: 282 if err[0] in (errno.ENOPROTOOPT, errno.EINVAL): 283 pass 284 else: 285 raise 286 _sock.settimeout(None) 287 self.raw_sock = _sock 288 289 def listen(self): 290 self._init_sock() 291 292 addr = self.unix_socket or (self.host, self.port) 293 self.raw_sock.bind(addr) 294 self.raw_sock.listen(self.backlog) 295 296 @asyncio.coroutine 297 def accept(self, callback): 298 server = yield from self.sock_factory( 299 lambda reader, writer: asyncio.wait_for( 300 callback(StreamHandler(reader, writer)), 301 self.client_timeout 302 ), 303 sock=self.raw_sock, 304 ssl=self.ssl_context 305 ) 306 return server 307 308 def close(self): 309 if not self.raw_sock: 310 return 311 312 try: 313 self.raw_sock.shutdown(socket.SHUT_RDWR) 314 self.raw_sock.close() 315 except (socket.error, OSError): 316 pass 317 318 319class StreamHandler(object): 320 def __init__(self, reader, writer): 321 self.reader, self.writer = reader, writer 322 323 @asyncio.coroutine 324 def read(self, sz): 325 try: 326 buff = yield from self.reader.read(sz) 327 except socket.error as e: 328 if (e.args[0] == errno.ECONNRESET and 329 (sys.platform == 'darwin' or 330 sys.platform.startswith('freebsd'))): 331 # freebsd and Mach don't follow POSIX semantic of recv 332 # and fail with ECONNRESET if peer performed shutdown. 333 # See corresponding comment and code in TSocket::read() 334 # in lib/cpp/src/transport/TSocket.cpp. 335 self.close() 336 # Trigger the check to raise the END_OF_FILE exception below. 337 buff = '' 338 else: 339 raise 340 341 if len(buff) == 0: 342 raise TTransportException(type=TTransportException.END_OF_FILE, 343 message='TSocket read 0 bytes') 344 return buff 345 346 def write(self, buff): 347 self.writer.write(buff) 348 349 @asyncio.coroutine 350 def flush(self): 351 yield from self.writer.drain() 352 353 def close(self): 354 try: 355 self.writer.close() 356 except (socket.error, OSError): 357 pass 358 359 @asyncio.coroutine 360 def open(self): 361 pass 362