#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

"""Blocking Thrift servers: simple, threaded, thread-pool and forking."""

from six.moves import queue
import logging
import os
import threading

from thrift.protocol import TBinaryProtocol
from thrift.protocol.THeaderProtocol import THeaderProtocolFactory
from thrift.transport import TTransport

logger = logging.getLogger(__name__)


def _try_close(trans):
    """Close *trans*, logging (instead of raising) any IOError from close()."""
    try:
        trans.close()
    except IOError as e:
        logger.warning(e, exc_info=True)


class TServer(object):
    """Base interface for a server, which must have a serve() method.

    Three constructors for all servers:
    1) (processor, serverTransport)
    2) (processor, serverTransport, transportFactory, protocolFactory)
    3) (processor, serverTransport,
        inputTransportFactory, outputTransportFactory,
        inputProtocolFactory, outputProtocolFactory)

    Form 1 uses TTransport.TTransportFactoryBase for both transports and
    TBinaryProtocol for both protocols; form 2 uses the same transport
    factory and the same protocol factory for input and output.
    """
    def __init__(self, *args):
        if len(args) == 2:
            self.__initArgs__(args[0], args[1],
                              TTransport.TTransportFactoryBase(),
                              TTransport.TTransportFactoryBase(),
                              TBinaryProtocol.TBinaryProtocolFactory(),
                              TBinaryProtocol.TBinaryProtocolFactory())
        elif len(args) == 4:
            self.__initArgs__(args[0], args[1], args[2], args[2], args[3], args[3])
        elif len(args) == 6:
            self.__initArgs__(*args)
        # NOTE(review): any other argument count is silently ignored and leaves
        # the server without processor/transport attributes. Kept for backward
        # compatibility; consider raising TypeError instead.

    def __initArgs__(self, processor, serverTransport,
                     inputTransportFactory, outputTransportFactory,
                     inputProtocolFactory, outputProtocolFactory):
        """Store the processor, the server transport and the four factories.

        Raises:
            ValueError: if exactly one of the two protocol factories is a
                THeaderProtocolFactory.
        """
        self.processor = processor
        self.serverTransport = serverTransport
        self.inputTransportFactory = inputTransportFactory
        self.outputTransportFactory = outputTransportFactory
        self.inputProtocolFactory = inputProtocolFactory
        self.outputProtocolFactory = outputProtocolFactory

        input_is_header = isinstance(self.inputProtocolFactory, THeaderProtocolFactory)
        output_is_header = isinstance(self.outputProtocolFactory, THeaderProtocolFactory)
        # THeaderProtocol replies on the same protocol instance that read the
        # request, so it cannot be paired with a different output protocol.
        if input_is_header != output_is_header:
            raise ValueError("THeaderProtocol servers require that both the input and "
                             "output protocols are THeaderProtocol.")

    def serve(self):
        """Run the server. The base implementation does nothing."""
        pass

    def _open_protocols(self, client):
        """Wrap an accepted *client* in transports and protocols.

        Returns an ``(itrans, iprot, otrans, oprot)`` tuple. For
        THeaderProtocol the same protocol instance is used for input and
        output so that the response is in the same dialect the server
        detected the request was in; ``otrans`` is None in that case.
        """
        itrans = self.inputTransportFactory.getTransport(client)
        iprot = self.inputProtocolFactory.getProtocol(itrans)
        if isinstance(self.inputProtocolFactory, THeaderProtocolFactory):
            return itrans, iprot, None, iprot
        otrans = self.outputTransportFactory.getTransport(client)
        oprot = self.outputProtocolFactory.getProtocol(otrans)
        return itrans, iprot, otrans, oprot

    def _process_requests(self, iprot, oprot):
        """Dispatch requests on one connection until processing stops.

        Returns True when processing ended with a TTransportException (treated
        as a normal end of the connection) and False when the processor raised
        any other exception, which is logged.
        """
        try:
            while True:
                self.processor.process(iprot, oprot)
        except TTransport.TTransportException:
            return True
        except Exception as x:
            logger.exception(x)
            return False

    def _serve_connection(self, client):
        """Process requests from *client*, then always close its transports."""
        itrans, iprot, otrans, oprot = self._open_protocols(client)
        try:
            self._process_requests(iprot, oprot)
        finally:
            itrans.close()
            if otrans:
                otrans.close()


class TSimpleServer(TServer):
    """Simple single-threaded server that just pumps around one transport."""

    def __init__(self, *args):
        TServer.__init__(self, *args)

    def serve(self):
        """Accept clients forever, serving each one to completion in turn."""
        self.serverTransport.listen()
        while True:
            client = self.serverTransport.accept()
            if not client:
                continue
            self._serve_connection(client)


class TThreadedServer(TServer):
    """Threaded server that spawns a new thread per each connection."""

    def __init__(self, *args, **kwargs):
        TServer.__init__(self, *args)
        self.daemon = kwargs.get("daemon", False)

    def serve(self):
        """Accept clients forever, handling each on a newly started thread."""
        self.serverTransport.listen()
        while True:
            try:
                client = self.serverTransport.accept()
                if not client:
                    continue
                t = threading.Thread(target=self.handle, args=(client,))
                # Thread.setDaemon() is deprecated (Python 3.10+); the
                # attribute works on every supported version.
                t.daemon = self.daemon
                t.start()
            except KeyboardInterrupt:
                raise
            except Exception as x:
                logger.exception(x)

    def handle(self, client):
        """Serve one client connection; this is the per-connection thread target."""
        self._serve_connection(client)


class TThreadPoolServer(TServer):
    """Server with a fixed size pool of threads which service requests."""

    def __init__(self, *args, **kwargs):
        TServer.__init__(self, *args)
        self.clients = queue.Queue()
        self.threads = 10
        self.daemon = kwargs.get("daemon", False)

    def setNumThreads(self, num):
        """Set the number of worker threads that should be created"""
        self.threads = num

    def serveThread(self):
        """Loop around getting clients from the shared queue and process them."""
        while True:
            try:
                client = self.clients.get()
                self.serveClient(client)
            except Exception as x:
                logger.exception(x)

    def serveClient(self, client):
        """Process input/output from a client for as long as possible"""
        self._serve_connection(client)

    def serve(self):
        """Start a fixed number of worker threads and put client into a queue"""
        for _ in range(self.threads):
            try:
                t = threading.Thread(target=self.serveThread)
                # Thread.setDaemon() is deprecated (Python 3.10+).
                t.daemon = self.daemon
                t.start()
            except Exception as x:
                logger.exception(x)

        # Pump the socket for clients
        self.serverTransport.listen()
        while True:
            try:
                client = self.serverTransport.accept()
                if not client:
                    continue
                self.clients.put(client)
            except Exception as x:
                logger.exception(x)


class TForkingServer(TServer):
    """A Thrift server that forks a new process for each request

    This is more scalable than the threaded server as it does not cause
    GIL contention.

    Note that this has different semantics from the threading server.
    Specifically, updates to shared variables will no longer be shared.
    It will also not work on Windows.

    This code is heavily inspired by SocketServer.ForkingMixIn in the
    Python stdlib.
    """
    def __init__(self, *args):
        TServer.__init__(self, *args)
        self.children = []

    def serve(self):
        """Accept clients forever, forking one child process per connection."""
        self.serverTransport.listen()
        while True:
            client = self.serverTransport.accept()
            if not client:
                continue
            try:
                try:
                    pid = os.fork()
                    if pid == 0:
                        # Child: never returns (always ends in os._exit).
                        self._serve_child(client)
                    # add before collect, otherwise you race w/ waitpid
                    self.children.append(pid)
                    self.collect_children()
                finally:
                    # Parent (whether or not the fork succeeded) must close its
                    # copy of the socket or the connection may not get closed
                    # promptly. Not reached in the child: os._exit skips it.
                    self._close_parent_copy(client)
            except TTransport.TTransportException:
                pass
            except Exception as x:
                logger.exception(x)

    def _close_parent_copy(self, client):
        """Close the parent's handle on *client* via both transport factories."""
        _try_close(self.inputTransportFactory.getTransport(client))
        _try_close(self.outputTransportFactory.getTransport(client))

    def _serve_child(self, client):
        """Serve *client* in the forked child, then terminate the child.

        Never returns: os._exit runs on every path, so an error here can never
        drop the child back into the parent's accept loop. Exit status is 0
        when processing ended with a TTransportException, 1 otherwise.
        """
        ecode = 1
        try:
            itrans, iprot, otrans, oprot = self._open_protocols(client)
            try:
                if self._process_requests(iprot, oprot):
                    ecode = 0
            finally:
                _try_close(itrans)
                if otrans:
                    _try_close(otrans)
        except Exception as e:
            logger.exception(e)
        finally:
            os._exit(ecode)

    def collect_children(self):
        """Reap exited children without blocking and forget their pids."""
        while self.children:
            try:
                pid, _status = os.waitpid(0, os.WNOHANG)
            except OSError:
                break
            if not pid:
                # WNOHANG: no child has exited yet.
                break
            # waitpid(0, ...) may reap a child we never recorded; ignore it
            # instead of letting list.remove raise ValueError.
            if pid in self.children:
                self.children.remove(pid)