1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10#   http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19
20from six.moves import queue
21import logging
22import os
23import threading
24
25from thrift.protocol import TBinaryProtocol
26from thrift.protocol.THeaderProtocol import THeaderProtocolFactory
27from thrift.transport import TTransport
28
# Module-level logger; the server classes below report accept-loop and
# request-handling failures through it (logger.exception / logger.warning).
logger = logging.getLogger(__name__)
30
31
class TServer(object):
    """Base interface for a server, which must have a serve() method.

    Three constructors for all servers:
    1) (processor, serverTransport)
    2) (processor, serverTransport, transportFactory, protocolFactory)
    3) (processor, serverTransport,
        inputTransportFactory, outputTransportFactory,
        inputProtocolFactory, outputProtocolFactory)

    Form 1 uses ``TTransport.TTransportFactoryBase`` for both transport
    factories and ``TBinaryProtocol.TBinaryProtocolFactory`` for both protocol
    factories. Form 2 uses the single transport factory and the single
    protocol factory for both input and output.

    Raises:
        TypeError: if called with any other number of positional arguments.
        ValueError: if exactly one of the input/output protocol factories is a
            ``THeaderProtocolFactory``.
    """
    def __init__(self, *args):
        if len(args) == 2:
            self.__initArgs__(args[0], args[1],
                              TTransport.TTransportFactoryBase(),
                              TTransport.TTransportFactoryBase(),
                              TBinaryProtocol.TBinaryProtocolFactory(),
                              TBinaryProtocol.TBinaryProtocolFactory())
        elif len(args) == 4:
            self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3])
        elif len(args) == 6:
            self.__initArgs__(*args)
        else:
            # Any other arity used to be silently ignored, leaving a server
            # without processor/transport attributes that only failed later
            # (with an AttributeError) once serve() was called.
            raise TypeError(
                "TServer expects 2, 4 or 6 positional arguments, got %d" % len(args))

    def __initArgs__(self, processor, serverTransport,
                     inputTransportFactory, outputTransportFactory,
                     inputProtocolFactory, outputProtocolFactory):
        """Store the processor, server transport and the four factories.

        Raises:
            ValueError: if only one of the protocol factories is a
                THeaderProtocolFactory. Header servers reuse the input protocol
                instance for output, so both sides must be header protocols.
        """
        self.processor = processor
        self.serverTransport = serverTransport
        self.inputTransportFactory = inputTransportFactory
        self.outputTransportFactory = outputTransportFactory
        self.inputProtocolFactory = inputProtocolFactory
        self.outputProtocolFactory = outputProtocolFactory

        input_is_header = isinstance(self.inputProtocolFactory, THeaderProtocolFactory)
        output_is_header = isinstance(self.outputProtocolFactory, THeaderProtocolFactory)
        if input_is_header != output_is_header:
            raise ValueError("THeaderProtocol servers require that both the input and "
                             "output protocols are THeaderProtocol.")

    def serve(self):
        """Run the server loop; subclasses override this. The base does nothing."""
        pass
72
73
class TSimpleServer(TServer):
    """Single-threaded server that serves one client connection at a time.

    An accepted client is processed until its transport raises
    TTransportException; only after that is the next client accepted.
    """

    def __init__(self, *args):
        TServer.__init__(self, *args)

    def serve(self):
        """Listen on the server transport and serve clients sequentially, forever."""
        self.serverTransport.listen()
        while True:
            connection = self.serverTransport.accept()
            if connection:
                self._serve_connection(connection)

    def _serve_connection(self, connection):
        """Run the processor on one client until its transport fails, then close it."""
        in_trans = self.inputTransportFactory.getTransport(connection)
        in_proto = self.inputProtocolFactory.getProtocol(in_trans)

        # THeaderProtocol works out the request's dialect while reading, so the
        # reply has to go through that very same protocol object.
        shares_protocol = isinstance(self.inputProtocolFactory, THeaderProtocolFactory)
        if shares_protocol:
            out_trans, out_proto = None, in_proto
        else:
            out_trans = self.outputTransportFactory.getTransport(connection)
            out_proto = self.outputProtocolFactory.getProtocol(out_trans)

        try:
            while True:
                self.processor.process(in_proto, out_proto)
        except TTransport.TTransportException:
            pass  # transport errors end the session without being logged
        except Exception as err:
            logger.exception(err)

        in_trans.close()
        if out_trans:
            out_trans.close()
111
112
class TThreadedServer(TServer):
    """Threaded server that spawns a new thread per each connection.

    Keyword Args:
        daemon (bool): mark handler threads as daemon threads. Defaults to False.
    """

    def __init__(self, *args, **kwargs):
        TServer.__init__(self, *args)
        self.daemon = kwargs.get("daemon", False)

    def serve(self):
        """Listen and hand every accepted client to a new handler thread, forever.

        Failures while accepting or starting a thread are logged and the loop
        continues. KeyboardInterrupt is not an Exception subclass, so it still
        propagates out of this method.
        """
        self.serverTransport.listen()
        while True:
            try:
                client = self.serverTransport.accept()
                if not client:
                    continue
                t = threading.Thread(target=self.handle, args=(client,))
                # Thread.setDaemon() is deprecated since Python 3.10; the
                # attribute form is equivalent and works on all versions.
                t.daemon = self.daemon
                t.start()
            except Exception as x:
                logger.exception(x)

    def handle(self, client):
        """Serve requests from one client until its transport fails, then close it."""
        itrans = self.inputTransportFactory.getTransport(client)
        iprot = self.inputProtocolFactory.getProtocol(itrans)

        # for THeaderProtocol, we must use the same protocol instance for input
        # and output so that the response is in the same dialect that the
        # server detected the request was in.
        if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
            otrans = None
            oprot = iprot
        else:
            otrans = self.outputTransportFactory.getTransport(client)
            oprot = self.outputProtocolFactory.getProtocol(otrans)

        try:
            while True:
                self.processor.process(iprot, oprot)
        except TTransport.TTransportException:
            pass
        except Exception as x:
            logger.exception(x)
        finally:
            # Close even if a BaseException (e.g. SystemExit) escapes the
            # processor, so the client connection is not leaked.
            itrans.close()
            if otrans:
                otrans.close()
160
161
class TThreadPoolServer(TServer):
    """Server with a fixed size pool of threads which service requests.

    Accepted clients are put on a shared queue. ``self.threads`` worker
    threads (10 by default) take clients off the queue and serve each one
    until its transport fails.

    Keyword Args:
        daemon (bool): mark worker threads as daemon threads. Defaults to False.
    """

    def __init__(self, *args, **kwargs):
        TServer.__init__(self, *args)
        self.clients = queue.Queue()
        self.threads = 10
        self.daemon = kwargs.get("daemon", False)

    def setNumThreads(self, num):
        """Set the number of worker threads that should be created.

        serve() reads this value once when it starts the workers, so call this
        before serve().
        """
        self.threads = num

    def serveThread(self):
        """Loop around getting clients from the shared queue and process them."""
        while True:
            try:
                client = self.clients.get()
                self.serveClient(client)
            except Exception as x:
                logger.exception(x)

    def serveClient(self, client):
        """Process requests from a client until its transport fails, then close it."""
        itrans = self.inputTransportFactory.getTransport(client)
        iprot = self.inputProtocolFactory.getProtocol(itrans)

        # for THeaderProtocol, we must use the same protocol instance for input
        # and output so that the response is in the same dialect that the
        # server detected the request was in.
        if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
            otrans = None
            oprot = iprot
        else:
            otrans = self.outputTransportFactory.getTransport(client)
            oprot = self.outputProtocolFactory.getProtocol(otrans)

        try:
            while True:
                self.processor.process(iprot, oprot)
        except TTransport.TTransportException:
            pass
        except Exception as x:
            logger.exception(x)
        finally:
            # Close even if a BaseException escapes the processor so the
            # client connection is not leaked by the worker thread.
            itrans.close()
            if otrans:
                otrans.close()

    def serve(self):
        """Start a fixed number of worker threads and put clients into a queue."""
        for _ in range(self.threads):
            try:
                t = threading.Thread(target=self.serveThread)
                # Thread.setDaemon() is deprecated since Python 3.10.
                t.daemon = self.daemon
                t.start()
            except Exception as x:
                logger.exception(x)

        # Pump the socket for clients
        self.serverTransport.listen()
        while True:
            try:
                client = self.serverTransport.accept()
                if not client:
                    continue
                self.clients.put(client)
            except Exception as x:
                logger.exception(x)
231
232
class TForkingServer(TServer):
    """A Thrift server that forks a new process for each client connection.

    This is more scalable than the threaded server as it does not cause
    GIL contention.

    Note that this has different semantics from the threading server.
    Specifically, updates to shared variables will no longer be shared.
    It relies on os.fork() and therefore does not work on Windows.

    This code is heavily inspired by SocketServer.ForkingMixIn in the
    Python stdlib.
    """
    def __init__(self, *args):
        TServer.__init__(self, *args)
        # pids of forked children that have not been reaped yet
        self.children = []

    @staticmethod
    def _try_close(file):
        """Close ``file``, logging (rather than raising) any IOError."""
        try:
            file.close()
        except IOError as e:
            logger.warning(e, exc_info=True)

    def serve(self):
        """Listen and fork one child process per accepted client, forever.

        Failures on the parent side (e.g. os.fork() raising) are logged and
        the accept loop continues; children never return here.
        """
        self.serverTransport.listen()
        while True:
            client = self.serverTransport.accept()
            if not client:
                continue
            try:
                pid = os.fork()
                if pid:  # parent
                    self._handle_in_parent(pid, client)
                else:  # child: never returns
                    self._handle_in_child(client)
            except TTransport.TTransportException:
                pass
            except Exception as x:
                logger.exception(x)

    def _handle_in_parent(self, pid, client):
        """Track the new child, reap finished ones, and close our copy of ``client``."""
        # add before collect, otherwise you race w/ waitpid
        self.children.append(pid)
        try:
            self.collect_children()
        finally:
            # Parent must close socket or the connection may not get closed
            # promptly -- do so even if reaping children failed.
            itrans = self.inputTransportFactory.getTransport(client)
            otrans = self.outputTransportFactory.getTransport(client)
            self._try_close(itrans)
            self._try_close(otrans)

    def _handle_in_child(self, client):
        """Serve ``client`` in the forked child process, then terminate it.

        The child must never return into serve()'s accept loop (it would start
        accepting connections itself), so os._exit() runs from a finally clause
        whatever happens. Exit status is 0 when the session ended with a
        transport error and 1 on any other failure.
        """
        ecode = 1
        try:
            ecode = self._serve_client(client)
        except Exception as e:
            # e.g. a transport or protocol factory failed before serving began
            logger.exception(e)
        finally:
            os._exit(ecode)

    def _serve_client(self, client):
        """Process requests from ``client`` until its transport fails.

        Returns 0 if the session ended with a TTransportException and 1 if the
        processor raised any other exception. Transports are closed either way.
        """
        itrans = self.inputTransportFactory.getTransport(client)
        iprot = self.inputProtocolFactory.getProtocol(itrans)

        # for THeaderProtocol, we must use the same protocol instance for
        # input and output so that the response is in the same dialect that
        # the server detected the request was in.
        if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
            otrans = None
            oprot = iprot
        else:
            otrans = self.outputTransportFactory.getTransport(client)
            oprot = self.outputProtocolFactory.getProtocol(otrans)

        ecode = 0
        try:
            while True:
                self.processor.process(iprot, oprot)
        except TTransport.TTransportException:
            pass
        except Exception as e:
            logger.exception(e)
            ecode = 1
        finally:
            self._try_close(itrans)
            if otrans:
                self._try_close(otrans)
        return ecode

    def collect_children(self):
        """Reap exited children without blocking.

        Repeatedly calls os.waitpid(0, os.WNOHANG) while children are tracked,
        dropping each reaped pid from ``self.children``; stops as soon as no
        child has exited yet or waitpid raises (e.g. no children remain).
        """
        while self.children:
            try:
                pid, _status = os.waitpid(0, os.WNOHANG)
            except OSError:
                pid = None

            if not pid:
                break
            # waitpid(0, ...) may reap a child this server did not fork (one
            # created elsewhere in the same process); list.remove() would raise
            # ValueError for such an untracked pid.
            if pid in self.children:
                self.children.remove(pid)
324