1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10#   http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19
20
21import logging
22
23from multiprocessing import Process, Value, Condition
24
25from .TServer import TServer
26from thrift.transport.TTransport import TTransportException
27
28logger = logging.getLogger(__name__)
29
30
class TProcessPoolServer(TServer):
    """Server with a fixed size pool of worker subprocesses to service requests

    ``serve()`` puts the server transport into listening mode, starts
    ``numWorkers`` daemon ``multiprocessing.Process`` workers that each call
    ``accept()`` on that transport, and then blocks until ``stop()`` is
    called (or the serving process receives ``KeyboardInterrupt`` /
    ``SystemExit``).

    Note that if you need shared state between the handlers - it's up to you!
    Written by Dvir Volk, doat.com
    """
    def __init__(self, *args):
        """Initialise the server; all arguments are forwarded to ``TServer``."""
        TServer.__init__(self, *args)
        # Number of worker processes started by serve().
        self.numWorkers = 10
        # Process objects of the workers that were started successfully.
        self.workers = []
        # Shared flag read by every worker loop; False asks them to exit.
        self.isRunning = Value('b', False)
        # serve() sleeps on this condition until stop() notifies it.
        self.stopCondition = Condition()
        # Optional callable run at the start of each worker process.
        self.postForkCallback = None

    def setPostForkCallback(self, callback):
        """Register a callable that each worker runs before accepting clients.

        Raises TypeError if ``callback`` is not callable.
        """
        if not callable(callback):
            raise TypeError("This is not a callback!")
        self.postForkCallback = callback

    def setNumWorkers(self, num):
        """Set the number of worker processes that serve() should create"""
        self.numWorkers = num

    def workerProcess(self):
        """Accept clients from the server transport and serve them in a loop.

        Runs inside a worker process. The loop continues for as long as the
        shared ``isRunning`` flag is true. Returns 0 when interrupted by
        ``KeyboardInterrupt`` or ``SystemExit``; any other exception is
        logged and the loop carries on with the next client.
        """
        if self.postForkCallback:
            self.postForkCallback()

        while self.isRunning.value:
            try:
                client = self.serverTransport.accept()
                if not client:
                    continue
                self.serveClient(client)
            except (KeyboardInterrupt, SystemExit):
                return 0
            except Exception as x:
                logger.exception(x)

    def serveClient(self, client):
        """Process input/output from a client for as long as possible

        Requests are processed until the processor raises. A
        ``TTransportException`` ends the session silently, any other
        ``Exception`` is logged. Both transports are closed on every exit
        path, including interruption by ``KeyboardInterrupt``/``SystemExit``.
        """
        itrans = self.inputTransportFactory.getTransport(client)
        otrans = self.outputTransportFactory.getTransport(client)
        iprot = self.inputProtocolFactory.getProtocol(itrans)
        oprot = self.outputProtocolFactory.getProtocol(otrans)

        try:
            while True:
                self.processor.process(iprot, oprot)
        except TTransportException:
            pass
        except Exception as x:
            logger.exception(x)
        finally:
            # Nested so a failure while closing the input transport cannot
            # prevent the output transport from being closed.
            try:
                itrans.close()
            finally:
                otrans.close()

    def serve(self):
        """Listen, start the worker processes and block until stopped.

        Returns once stop() has been called, or when the serving process is
        interrupted by ``KeyboardInterrupt``/``SystemExit``. The shared
        ``isRunning`` flag is cleared on every exit path. Workers are daemon
        processes and are not joined here; each one re-checks ``isRunning``
        only after its current ``accept()``/``serveClient()`` call finishes.
        """
        # this is a shared state that can tell the workers to exit when False
        self.isRunning.value = True

        # first bind and listen to the port
        self.serverTransport.listen()

        # fork the children
        for _ in range(self.numWorkers):
            try:
                w = Process(target=self.workerProcess)
                w.daemon = True
                w.start()
                self.workers.append(w)
            except Exception as x:
                logger.exception(x)

        # wait until the condition is set by stop()
        try:
            # The with-block guarantees the condition's lock is released
            # again when serve() leaves, whatever the exit path.
            with self.stopCondition:
                # Waiting on the flag (not just on the notification) means a
                # stop() that ran before we got here is not missed.
                while self.isRunning.value:
                    try:
                        self.stopCondition.wait()
                    except Exception as x:
                        logger.exception(x)
        except (SystemExit, KeyboardInterrupt):
            pass
        finally:
            self.isRunning.value = False

    def stop(self):
        """Ask the workers to exit and wake up the blocked serve() call."""
        # Clear the flag before notifying so serve() sees it when it wakes.
        self.isRunning.value = False
        with self.stopCondition:
            self.stopCondition.notify()
124