1# -*- coding: utf-8 -*-
2
3from __future__ import absolute_import, division
4
5import errno
6import os
7import socket
8import struct
9import sys
10
11from . import TTransportException
12
13
class TSocket(object):
    """Socket implementation for client side."""

    def __init__(self, host=None, port=None, unix_socket=None,
                 sock=None, socket_family=socket.AF_INET,
                 socket_timeout=3000, connect_timeout=None):
        """Initialize a TSocket

        TSocket can be initialized in 3 ways:
        * host + port. can configure to use AF_INET/AF_INET6
        * unix_socket
        * socket. should pass already opened socket here.

        @param host(str)    The host to connect to.
        @param port(int)    The (TCP) port to connect to.
        @param unix_socket(str) The filename of a unix socket to connect to.
        @param sock(socket)     Initialize with opened socket directly.
            If this param used, the host, port and unix_socket params will
            be ignored.
        @param socket_family(str) socket.AF_INET or socket.AF_INET6. only
            take effect when using host/port
        @param socket_timeout   socket timeout in ms
        @param connect_timeout  connect timeout in ms, only used in
            connection, will be set to socket_timeout if not set.
        """
        # Always define every address attribute, whichever construction mode
        # is used, so later calls (e.g. open()) never hit a missing attribute.
        self.unix_socket = None
        self.host = None
        self.port = None
        self.sock = None

        if sock:
            self.sock = sock
        elif unix_socket:
            self.unix_socket = unix_socket
        else:
            self.host = host
            self.port = port

        self.socket_family = socket_family
        # Timeouts are given in milliseconds but stored in seconds, which is
        # what socket.settimeout() expects. A falsy value means "no timeout".
        self.socket_timeout = socket_timeout / 1000 if socket_timeout else None
        self.connect_timeout = connect_timeout / 1000 if connect_timeout \
            else self.socket_timeout

    def _init_sock(self):
        """Create a fresh, unconnected socket and store it in self.sock."""
        if self.unix_socket:
            _sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        else:
            _sock = socket.socket(self.socket_family, socket.SOCK_STREAM)
            _sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

        # socket options
        linger = struct.pack('ii', 0, 0)
        _sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, linger)
        _sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

        self.sock = _sock

    def set_handle(self, sock):
        """Replace the underlying socket with ``sock``."""
        self.sock = sock

    def set_timeout(self, ms):
        """Backward compat api, will bind the timeout to both connect_timeout
        and socket_timeout.

        @param ms   timeout in ms; a falsy or non-positive value disables
            the timeout.
        """
        self.socket_timeout = ms / 1000 if (ms and ms > 0) else None
        self.connect_timeout = self.socket_timeout

        if self.sock is not None:
            self.sock.settimeout(self.socket_timeout)

    def is_open(self):
        """Return True if this transport currently holds a socket."""
        return bool(self.sock)

    def open(self):
        """Create a new socket and connect it to the configured address.

        Raises TTransportException (NOT_OPEN) if the connection fails; in
        that case the half-initialized socket is closed and discarded.
        """
        self._init_sock()

        addr = self.unix_socket or (self.host, self.port)

        try:
            if self.connect_timeout:
                self.sock.settimeout(self.connect_timeout)

            self.sock.connect(addr)

            if self.socket_timeout:
                self.sock.settimeout(self.socket_timeout)
            elif self.connect_timeout:
                # The connect timeout must not leak into regular I/O when
                # the socket timeout is disabled.
                self.sock.settimeout(None)

        except (socket.error, OSError):
            # Release the file descriptor and make is_open() report False
            # instead of leaving a dead socket behind.
            failed, self.sock = self.sock, None
            try:
                failed.close()
            except (socket.error, OSError):
                pass
            raise TTransportException(
                type=TTransportException.NOT_OPEN,
                message="Could not connect to %s" % str(addr))

    def read(self, sz):
        """Receive up to ``sz`` bytes from the socket.

        Raises TTransportException (END_OF_FILE) when the peer closed the
        connection, i.e. when zero bytes were received.
        """
        try:
            buff = self.sock.recv(sz)
        except socket.error as e:
            if (e.errno == errno.ECONNRESET and
                    (sys.platform == 'darwin' or
                     sys.platform.startswith('freebsd'))):
                # freebsd and Mach don't follow POSIX semantic of recv
                # and fail with ECONNRESET if peer performed shutdown.
                # See corresponding comment and code in TSocket::read()
                # in lib/cpp/src/transport/TSocket.cpp.
                self.close()
                # Trigger the check to raise the END_OF_FILE exception below.
                buff = b''
            else:
                raise

        if len(buff) == 0:
            raise TTransportException(type=TTransportException.END_OF_FILE,
                                      message='TSocket read 0 bytes')
        return buff

    def write(self, buff):
        """Send all of ``buff`` over the socket."""
        self.sock.sendall(buff)

    def flush(self):
        """No-op: write() already sends the data immediately."""
        pass

    def close(self):
        """Shut down and close the socket, if any. Safe to call repeatedly."""
        if not self.sock:
            return

        # Drop the reference first so is_open() is False afterwards even if
        # one of the calls below fails.
        sock, self.sock = self.sock, None

        try:
            sock.shutdown(socket.SHUT_RDWR)
        except (socket.error, OSError):
            # e.g. the socket was never connected or the peer already left;
            # still fall through to close() so the descriptor is released.
            pass

        try:
            sock.close()
        except (socket.error, OSError):
            pass
144
class TServerSocket(object):
    """Socket implementation for server side."""

    def __init__(self, host=None, port=None, unix_socket=None,
                 socket_family=socket.AF_INET, client_timeout=3000,
                 backlog=128):
        """Initialize a TServerSocket

        TServerSocket can be initialized in 2 ways:
        * host + port. can configure to use AF_INET/AF_INET6
        * unix_socket

        @param host(str)    The host to bind to
        @param port(int)    The (TCP) port to bind to
        @param unix_socket(str) The filename of a unix socket to bind to
        @param socket_family(str) socket.AF_INET or socket.AF_INET6. only
            take effect when using host/port
        @param client_timeout   client socket timeout in ms
        @param backlog          backlog for server socket
        """

        if unix_socket:
            self.unix_socket = unix_socket
            self.host = None
            self.port = None
        else:
            self.unix_socket = None
            self.host = host
            self.port = port

        self.socket_family = socket_family
        # Given in milliseconds, stored in seconds for socket.settimeout().
        self.client_timeout = client_timeout / 1000 if client_timeout else None
        self.backlog = backlog
        # Created lazily by listen(); defined here so close() is safe to call
        # on a server socket that never started listening.
        self.sock = None

    def _remove_stale_unix_socket(self):
        """Unlink the unix socket file if nothing is listening on it."""
        probe = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            probe.connect(self.unix_socket)
        except (socket.error, OSError) as err:
            # ECONNREFUSED means the file exists but has no listener behind
            # it, i.e. it is a leftover from a previous run.
            if err.errno == errno.ECONNREFUSED:
                os.unlink(self.unix_socket)
        finally:
            # Use a throw-away socket for the probe so the listening socket
            # is never left connected to another live server.
            probe.close()

    def _init_sock(self):
        """Create the (not yet bound) listening socket in self.sock."""
        if self.unix_socket:
            # try remove the sock file if it already exists
            self._remove_stale_unix_socket()
            _sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        else:
            _sock = socket.socket(self.socket_family, socket.SOCK_STREAM)
            # SO_REUSEPORT only makes sense for inet sockets; some kernels
            # reject it on AF_UNIX, so it is applied to this branch only.
            if hasattr(socket, "SO_REUSEPORT"):
                _sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)

        _sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        _sock.settimeout(None)
        self.sock = _sock

    def listen(self):
        """Create the socket, bind it to the configured address and listen."""
        self._init_sock()

        addr = self.unix_socket or (self.host, self.port)
        self.sock.bind(addr)
        self.sock.listen(self.backlog)

    def accept(self):
        """Accept one client connection and wrap it in a TSocket.

        The accepted socket gets ``client_timeout`` applied.
        """
        client, _ = self.sock.accept()
        client.settimeout(self.client_timeout)
        return TSocket(sock=client)

    def close(self):
        """Shut down and close the listening socket, if any."""
        if not self.sock:
            return

        try:
            self.sock.shutdown(socket.SHUT_RDWR)
        except (socket.error, OSError):
            # shutdown() on a listening (never connected) socket fails on
            # some platforms; the descriptor must still be closed below.
            pass

        try:
            self.sock.close()
        except (socket.error, OSError):
            pass
218