#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

import errno
import logging
import os
import socket
import sys

from .TTransport import TTransportBase, TTransportException, TServerTransportBase

logger = logging.getLogger(__name__)


class TSocketBase(TTransportBase):
    """Address resolution and close logic shared by TSocket and TServerSocket.

    Subclasses are expected to set ``host``, ``port``, ``handle``,
    ``_unix_socket`` and ``_socket_family`` in their constructors.
    """

    def _resolveAddr(self):
        """Return candidate addresses as getaddrinfo-style 5-tuples.

        For a unix socket a single ``(AF_UNIX, SOCK_STREAM, None, None, path)``
        entry is returned; otherwise ``socket.getaddrinfo`` is queried for
        ``self.host``/``self.port`` restricted to ``self._socket_family`` and
        ``SOCK_STREAM`` (with ``AI_PASSIVE``, so a ``None`` host yields
        wildcard addresses suitable for binding).
        """
        if self._unix_socket is not None:
            return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None,
                     self._unix_socket)]
        else:
            return socket.getaddrinfo(self.host,
                                      self.port,
                                      self._socket_family,
                                      socket.SOCK_STREAM,
                                      0,
                                      socket.AI_PASSIVE)

    def close(self):
        """Close the underlying socket, if any, and drop the reference."""
        if self.handle:
            self.handle.close()
            self.handle = None


class TSocket(TSocketBase):
    """Socket implementation of TTransport base."""

    def __init__(self, host='localhost', port=9090, unix_socket=None,
                 socket_family=socket.AF_UNSPEC,
                 socket_keepalive=False):
        """Initialize a TSocket

        @param host(str)  The host to connect to.
        @param port(int)  The (TCP) port to connect to.
        @param unix_socket(str)  The filename of a unix socket to connect to.
                                 (host and port will be ignored.)
        @param socket_family(int)  The socket family to use with this socket.
        @param socket_keepalive(bool) enable TCP keepalive, default off.
        """
        self.host = host
        self.port = port
        self.handle = None
        self._unix_socket = unix_socket
        # Timeout in seconds (None = blocking); set via setTimeout().
        self._timeout = None
        self._socket_family = socket_family
        self._socket_keepalive = socket_keepalive

    def setHandle(self, h):
        """Adopt an already-connected socket object (used by TServerSocket.accept)."""
        self.handle = h

    def isOpen(self):
        """Return True if a socket is held and the peer has not closed it."""
        if self.handle is None:
            return False

        # this lets us cheaply see if the other end of the socket is still
        # connected. if disconnected, we'll get EOF back (expressed as zero
        # bytes of data) otherwise we'll get one byte or an error indicating
        # we'd have to block for data.
        #
        # note that we're not doing this with socket.MSG_DONTWAIT because 1)
        # it's linux-specific and 2) gevent-patched sockets hide EAGAIN from us
        # when timeout is non-zero.
        original_timeout = self.handle.gettimeout()
        try:
            self.handle.settimeout(0)
            try:
                peeked_bytes = self.handle.recv(1, socket.MSG_PEEK)
            except (socket.error, OSError) as exc:  # on modern python this is just BlockingIOError
                if exc.errno in (errno.EWOULDBLOCK, errno.EAGAIN):
                    return True
                return False
        finally:
            self.handle.settimeout(original_timeout)

        # the length will be zero if we got EOF (indicating connection closed)
        return len(peeked_bytes) == 1

    def setTimeout(self, ms):
        """Set the socket timeout in milliseconds (None disables the timeout).

        The value is stored in seconds and applied immediately to an open
        socket, and to any socket created later by open().
        """
        if ms is None:
            self._timeout = None
        else:
            self._timeout = ms / 1000.0

        if self.handle is not None:
            self.handle.settimeout(self._timeout)

    def _do_open(self, family, socktype):
        """Create the raw socket; presumably a hook for subclasses to override."""
        return socket.socket(family, socktype)

    @property
    def _address(self):
        """Human-readable target address used in error messages."""
        return self._unix_socket if self._unix_socket else '%s:%d' % (self.host, self.port)

    def open(self):
        """Connect to the first reachable resolved address.

        Raises TTransportException with type ALREADY_OPEN if a socket is
        already held, or NOT_OPEN if resolution fails or no address accepts
        the connection.
        """
        if self.handle:
            raise TTransportException(type=TTransportException.ALREADY_OPEN, message="already open")
        try:
            addrs = self._resolveAddr()
        except socket.gaierror as gai:
            msg = 'failed to resolve sockaddr for ' + str(self._address)
            logger.exception(msg)
            raise TTransportException(type=TTransportException.NOT_OPEN, message=msg, inner=gai)
        for family, socktype, _, _, sockaddr in addrs:
            handle = self._do_open(family, socktype)

            # SO_KEEPALIVE is a socket-level (SOL_SOCKET) option. Passing it at
            # the IPPROTO_TCP level selects an unrelated TCP option that
            # happens to share its numeric value, so keepalive would never be
            # enabled.
            if self._socket_keepalive:
                handle.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

            handle.settimeout(self._timeout)
            try:
                handle.connect(sockaddr)
                self.handle = handle
                return
            except socket.error:
                handle.close()
                logger.info('Could not connect to %s', sockaddr, exc_info=True)
        msg = 'Could not connect to any of %s' % [addr[4] for addr in addrs]
        logger.error(msg)
        raise TTransportException(type=TTransportException.NOT_OPEN, message=msg)

    def read(self, sz):
        """Read up to ``sz`` bytes from the socket.

        Raises TTransportException: NOT_OPEN if no socket is held, TIMED_OUT
        on a read timeout, END_OF_FILE when the peer has closed the
        connection, and an untyped one for other socket errors.
        """
        if not self.handle:
            # Same contract as write(): fail with a transport error rather
            # than an AttributeError on None.
            raise TTransportException(type=TTransportException.NOT_OPEN,
                                      message='Transport not open')
        try:
            buff = self.handle.recv(sz)
        except socket.timeout as e:
            # A settimeout()-based timeout raises socket.timeout (an alias of
            # TimeoutError since Python 3.10) whose args carry a message, not
            # an errno, so it must be classified before the errno checks.
            raise TTransportException(type=TTransportException.TIMED_OUT, message="read timeout", inner=e)
        except socket.error as e:
            if (e.args[0] == errno.ECONNRESET and
                    (sys.platform == 'darwin' or sys.platform.startswith('freebsd'))):
                # freebsd and Mach don't follow POSIX semantic of recv
                # and fail with ECONNRESET if peer performed shutdown.
                # See corresponding comment and code in TSocket::read()
                # in lib/cpp/src/transport/TSocket.cpp.
                self.close()
                # Trigger the check to raise the END_OF_FILE exception below.
                buff = ''
            elif e.args[0] == errno.ETIMEDOUT:
                raise TTransportException(type=TTransportException.TIMED_OUT, message="read timeout", inner=e)
            else:
                raise TTransportException(message="unexpected exception", inner=e)
        if len(buff) == 0:
            raise TTransportException(type=TTransportException.END_OF_FILE,
                                      message='TSocket read 0 bytes')
        return buff

    def write(self, buff):
        """Send all of ``buff``, looping until every byte has been written.

        Raises TTransportException: NOT_OPEN if no socket is held,
        END_OF_FILE if send() reports 0 bytes, and an untyped one for
        socket errors.
        """
        if not self.handle:
            raise TTransportException(type=TTransportException.NOT_OPEN,
                                      message='Transport not open')
        sent = 0
        have = len(buff)
        while sent < have:
            try:
                plus = self.handle.send(buff)
                if plus == 0:
                    raise TTransportException(type=TTransportException.END_OF_FILE,
                                              message='TSocket sent 0 bytes')
                sent += plus
                buff = buff[plus:]
            except socket.error as e:
                raise TTransportException(message="unexpected exception", inner=e)

    def flush(self):
        """No-op: writes go straight to the socket."""
        pass


class TServerSocket(TSocketBase, TServerTransportBase):
    """Socket implementation of TServerTransport base."""

    def __init__(self, host=None, port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
        self.host = host
        self.port = port
        self._unix_socket = unix_socket
        self._socket_family = socket_family
        self.handle = None
        self._backlog = 128

    def setBacklog(self, backlog=None):
        """Set the listen() backlog; only effective before listen() is called."""
        if not self.handle:
            self._backlog = backlog
        else:
            # We can't update the backlog once listening, since the
            # handle has already been created.
            logger.warning('You have to set backlog before listen.')

    def listen(self):
        """Bind to the resolved address and start listening.

        When the family is AF_UNSPEC an AF_INET6 address is preferred;
        otherwise the first address of the configured family is used,
        falling back to the last resolved address.
        """
        res0 = self._resolveAddr()
        socket_family = socket.AF_INET6 if self._socket_family == socket.AF_UNSPEC else self._socket_family
        for res in res0:
            if res[0] == socket_family or res is res0[-1]:
                break

        # We need to remove the old unix socket if the file exists and
        # nobody is listening on it.
        if self._unix_socket:
            tmp = socket.socket(res[0], res[1])
            try:
                tmp.connect(res[4])
            except socket.error as err:
                # err.errno is None (rather than raising) for errors without
                # an (errno, message) args pair.
                if err.errno == errno.ECONNREFUSED:
                    os.unlink(res[4])
            finally:
                # The probe socket is never needed past this point.
                tmp.close()

        self.handle = socket.socket(res[0], res[1])
        self.handle.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        if hasattr(self.handle, 'settimeout'):
            self.handle.settimeout(None)
        self.handle.bind(res[4])
        self.handle.listen(self._backlog)

    def accept(self):
        """Block until a client connects and return it wrapped in a TSocket."""
        client, addr = self.handle.accept()
        result = TSocket()
        result.setHandle(client)
        return result