1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10#   http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19"""Implementation of non-blocking server.
20
21The main idea of the server is to receive and send requests
22only from the main thread.
23
The thread pool should be sized for concurrent tasks, not
maximum connections.
26"""
27
import errno
import functools
import logging
import select
import socket
import struct
import threading

from collections import deque
from six.moves import queue

from thrift.transport import TTransport
from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory
39
# Public API of this module: only the server class is exported by
# ``from ... import *``.
__all__ = ['TNonblockingServer']

# Module-level logger used by the worker threads, connections and server.
logger = logging.getLogger(__name__)
43
44
class Worker(threading.Thread):
    """Thread that runs queued requests through a processor.

    Every task taken from the queue is a 5-item sequence
    ``(processor, iprot, oprot, otrans, callback)``. A task whose processor
    is None tells the worker to exit.
    """

    def __init__(self, queue):
        threading.Thread.__init__(self)
        self.queue = queue

    def run(self):
        """Process tasks from the queue until a None processor is received.

        On success ``callback(True, otrans.getvalue())`` is invoked. If the
        processor (or that callback) raises, the error is logged and
        ``callback(False, b'')`` is invoked instead. An exception escaping the
        failure callback is logged too, so the worker thread keeps running.
        """
        while True:
            processor, iprot, oprot, otrans, callback = self.queue.get()
            if processor is None:
                break
            try:
                processor.process(iprot, oprot)
                callback(True, otrans.getvalue())
            except Exception:
                logger.exception("Exception while processing request")
                try:
                    callback(False, b'')
                except Exception:
                    # Without this guard a failing callback would propagate out
                    # of run(), killing this thread and shrinking the pool.
                    logger.exception("Exception while reporting failed request")
64
65
# Connection states, in the order a request normally moves through them:
# read the length, read the body, wait for a worker, send the answer.
# CLOSED marks a connection that should be dropped.
WAIT_LEN, WAIT_MESSAGE, WAIT_PROCESS, SEND_ANSWER, CLOSED = range(5)
71
72
def locked(func):
    """Decorator that runs the wrapped method while holding ``self.lock``.

    The lock is released even if the method raises. ``functools.wraps``
    keeps the wrapped method's name and docstring.
    """
    @functools.wraps(func)
    def nested(self, *args, **kwargs):
        with self.lock:
            return func(self, *args, **kwargs)
    return nested
82
83
def socket_exception(func):
    """Decorator that closes the object when the method raises socket.error.

    The error is logged at debug level and swallowed; the decorated call
    then returns None. Other exceptions propagate unchanged.
    ``functools.wraps`` keeps the wrapped method's name and docstring.
    """
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        try:
            return func(self, *args, **kwargs)
        except socket.error:
            logger.debug('ignoring socket exception', exc_info=True)
            self.close()
    return wrapper
93
94
class Message(object):
    """A frame, or a frame's 4-byte length header, located in a read buffer.

    ``offset`` is where this item starts in the buffer, ``len`` its size in
    bytes, and ``is_header`` tells whether it is the length prefix rather
    than the payload. ``buffer`` is filled in once the payload is complete.
    """

    def __init__(self, offset, len_, header):
        self.is_header = header
        self.buffer = None
        self.len = len_
        self.offset = offset

    @property
    def end(self):
        """Index one past the last byte of this item within the buffer."""
        start, size = self.offset, self.len
        return start + size
105
106
class Connection(object):
    """A single client connection driven by the server's main loop.

    It can be in one of these states:
        WAIT_LEN --- connection is reading the 4-byte request length.
        WAIT_MESSAGE --- connection is reading the request body.
        WAIT_PROCESS --- a whole request has been read and the connection
                         waits for the ready() callback.
        SEND_ANSWER --- connection is sending the answer (including its
                        4-byte length prefix).
        CLOSED --- socket was closed and the connection should be deleted.
    """
    def __init__(self, new_socket, wake_up):
        self.socket = new_socket
        self.socket.setblocking(False)
        self.status = WAIT_LEN
        self.len = 0
        # Complete frames read from the socket but not yet dispatched.
        self.received = deque()
        # Frame (or 4-byte length header) currently being assembled in _rbuf.
        self._reading = Message(0, 4, True)
        self._rbuf = b''
        self._wbuf = b''
        self.lock = threading.Lock()
        # Callable invoked by ready() to interrupt the main thread.
        self.wake_up = wake_up
        # True when the last recv() filled the whole buffer, i.e. more data
        # may already be waiting in the socket.
        self.remaining = False

    @socket_exception
    def read(self):
        """Read available data from the socket and advance the state.

        Complete frames are appended to ``self.received``; once at least one
        is available the state switches to WAIT_PROCESS. A zero-length first
        read closes the connection, as does a negative frame length prefix.
        "Would block" from the non-blocking socket simply ends the read.
        """
        assert self.status in (WAIT_LEN, WAIT_MESSAGE)
        assert not self.received
        buf_size = 8192
        first = True
        done = False
        while not done:
            try:
                read = self.socket.recv(buf_size)
            except socket.error as e:
                if e.errno not in (errno.EAGAIN, errno.EWOULDBLOCK):
                    raise
                # Nothing is buffered right now (e.g. the previous recv()
                # returned exactly buf_size bytes). That is not a failure:
                # stop here and wait for the next select() wake-up.
                done = True
                break
            rlen = len(read)
            done = rlen < buf_size
            self._rbuf += read
            if first and rlen == 0:
                if self.status != WAIT_LEN or self._rbuf:
                    logger.error('could not read frame from socket')
                else:
                    logger.debug('read zero length. client might have disconnected')
                self.close()
            while len(self._rbuf) >= self._reading.end:
                if self._reading.is_header:
                    mlen, = struct.unpack('!i', self._rbuf[:4])
                    if mlen < 0:
                        # A negative length would make the slicing below
                        # desynchronize the stream; the peer is broken.
                        logger.error('invalid negative frame size %d', mlen)
                        self.remaining = False
                        self.close()
                        return
                    self._reading = Message(self._reading.end, mlen, False)
                    self.status = WAIT_MESSAGE
                else:
                    self._reading.buffer = self._rbuf
                    self.received.append(self._reading)
                    self._rbuf = self._rbuf[self._reading.end:]
                    self._reading = Message(0, 4, True)
            first = False
            if self.received:
                self.status = WAIT_PROCESS
                break
        self.remaining = not done

    @socket_exception
    def write(self):
        """Write pending answer data to the socket and switch state.

        Once the whole answer is sent, the connection goes back to WAIT_LEN.
        """
        assert self.status == SEND_ANSWER
        sent = self.socket.send(self._wbuf)
        if sent == len(self._wbuf):
            self.status = WAIT_LEN
            self._wbuf = b''
            self.len = 0
        else:
            self._wbuf = self._wbuf[sent:]

    @locked
    def ready(self, all_ok, message):
        """Callback for switching state and waking up the main thread.

        This is the only method which may be called asynchronously (it is
        handed to the worker threads as the task callback).

        ready() can switch the connection to three states:
            WAIT_LEN if the request was oneway (empty answer).
            SEND_ANSWER if the request was processed normally.
            CLOSED if processing raised an unexpected exception.

        If the connection was closed while the request was being processed,
        there is nobody to answer and the call does nothing.

        Args:
            all_ok: False if processing failed.
            message: serialized answer bytes; empty for oneway requests.
        """
        if self.status == CLOSED:
            return
        assert self.status == WAIT_PROCESS
        if not all_ok:
            self.close()
            self.wake_up()
            return
        self.len = 0
        if len(message) == 0:
            # it was a oneway request, do not write answer
            self._wbuf = b''
            self.status = WAIT_LEN
        else:
            self._wbuf = struct.pack('!i', len(message)) + message
            self.status = SEND_ANSWER
        self.wake_up()

    @locked
    def is_writeable(self):
        """Return True if connection should be added to write list of select"""
        return self.status == SEND_ANSWER

    # it's not necessary, but...
    @locked
    def is_readable(self):
        """Return True if connection should be added to read list of select"""
        return self.status in (WAIT_LEN, WAIT_MESSAGE)

    @locked
    def is_closed(self):
        """Returns True if connection is closed."""
        return self.status == CLOSED

    def fileno(self):
        """Returns the file descriptor of the associated socket."""
        return self.socket.fileno()

    def close(self):
        """Mark the connection CLOSED and close its socket."""
        self.status = CLOSED
        self.socket.close()
231
232
class TNonblockingServer(object):
    """Non-blocking server.

    The main thread multiplexes all client sockets with select(); complete
    requests are handed to a pool of Worker threads through ``self.tasks``.
    """

    def __init__(self,
                 processor,
                 lsocket,
                 inputProtocolFactory=None,
                 outputProtocolFactory=None,
                 threads=10):
        self.processor = processor
        self.socket = lsocket
        self.in_protocol = inputProtocolFactory or TBinaryProtocolFactory()
        self.out_protocol = outputProtocolFactory or self.in_protocol
        self.threads = int(threads)
        # Maps a client socket's file descriptor to its Connection.
        self.clients = {}
        self.tasks = queue.Queue()
        # Socket pair used by wake_up() to interrupt a blocking select().
        self._read, self._write = socket.socketpair()
        self.prepared = False
        self._stop = False

    def setNumThreads(self, num):
        """Set the number of worker threads that should be created."""
        # implement ThreadPool interface
        assert not self.prepared, "Can't change number of threads after start"
        self.threads = num

    def prepare(self):
        """Start listening and spawn the worker threads (once)."""
        if self.prepared:
            return
        self.socket.listen()
        for _ in range(self.threads):
            thread = Worker(self.tasks)
            # Thread.setDaemon() is deprecated (Python 3.10+); the attribute
            # works on all supported versions.
            thread.daemon = True
            thread.start()
        self.prepared = True

    def wake_up(self):
        """Wake up the main thread.

        The main thread usually blocks in select(). To interrupt it (for
        example when a worker finished a request, or stop() is called) we
        write a byte to one end of a socketpair; the other end is always in
        the select() read list, so select() returns.
        """
        self._write.send(b'1')

    def stop(self):
        """Stop the server.

        This method causes the serve() method to return.  stop() may be invoked
        from within your handler, or from another thread.

        After stop() is called, serve() will return but the server will still
        be listening on the socket.  serve() may then be called again to resume
        processing requests.  Alternatively, close() may be called after
        serve() returns to close the server socket and shutdown all worker
        threads.
        """
        self._stop = True
        self.wake_up()

    def _select(self):
        """Select on open connections.

        Returns ``(rset, wset, xset, selected)``. When some readable
        connection already has queued frames or possibly unread socket data,
        select() is skipped: those connections are returned as ``rset`` with
        ``selected`` False. Closed connections are dropped from ``clients``.
        """
        readable = [self.socket.handle.fileno(), self._read.fileno()]
        writable = []
        remaining = []
        for fd, connection in list(self.clients.items()):
            if connection.is_closed():
                del self.clients[fd]
                continue
            if connection.is_readable():
                readable.append(connection.fileno())
                if connection.remaining or connection.received:
                    remaining.append(connection.fileno())
            if connection.is_writeable():
                writable.append(connection.fileno())
        if remaining:
            return remaining, [], [], False
        return select.select(readable, writable, readable) + (True,)

    def handle(self):
        """Handle requests.

        WARNING! You must call prepare() BEFORE calling handle()
        """
        assert self.prepared, "You have to call prepare before handle"
        rset, wset, xset, _selected = self._select()
        for readable in rset:
            if readable == self._read.fileno():
                # don't care i just need to clean readable flag
                self._read.recv(1024)
            elif readable == self.socket.handle.fileno():
                try:
                    client = self.socket.accept()
                    if client:
                        self.clients[client.handle.fileno()] = Connection(client.handle,
                                                                          self.wake_up)
                except socket.error:
                    logger.debug('error while accepting', exc_info=True)
            else:
                connection = self.clients[readable]
                # Read whenever no complete frame is queued: either select()
                # reported data, or the last recv() filled its buffer
                # (connection.remaining) and more may be waiting. Skipping the
                # read in the latter case never resets `remaining`, so the
                # loop would spin forever without calling select()/accept().
                if not connection.received:
                    connection.read()
                if connection.received:
                    connection.status = WAIT_PROCESS
                    msg = connection.received.popleft()
                    itransport = TTransport.TMemoryBuffer(msg.buffer, msg.offset)
                    otransport = TTransport.TMemoryBuffer()
                    iprot = self.in_protocol.getProtocol(itransport)
                    oprot = self.out_protocol.getProtocol(otransport)
                    self.tasks.put([self.processor, iprot, oprot,
                                    otransport, connection.ready])
        for writeable in wset:
            self.clients[writeable].write()
        for oob in xset:
            # xset may contain the listening socket or the wake-up socket,
            # which are not clients; ignore those instead of raising KeyError.
            connection = self.clients.pop(oob, None)
            if connection is not None:
                connection.close()

    def close(self):
        """Close the listening socket and tell every worker thread to exit."""
        for _ in range(self.threads):
            self.tasks.put([None, None, None, None, None])
        self.socket.close()
        self.prepared = False

    def serve(self):
        """Serve requests.

        Serve requests forever, or until stop() is called.
        """
        self._stop = False
        self.prepare()
        while not self._stop:
            self.handle()
371