#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

import ssl

from six.moves import BaseHTTPServer

from thrift.Thrift import TMessageType
from thrift.server import TServer
from thrift.transport import TTransport


class ResponseException(Exception):
    """Allows handlers to override the HTTP response

    Normally, THttpServer always sends a 200 response.  If a handler wants
    to override this behavior (e.g., to simulate a misconfigured or
    overloaded web server during testing), it can raise a ResponseException.
    The function passed to the constructor will be called with the
    RequestHandler as its only argument.  Note that this is irrelevant
    for ONEWAY requests, as the HTTP response must be sent before the
    RPC is processed.
    """
    def __init__(self, handler):
        # Callable invoked as ``handler(request_handler)`` to write the
        # replacement HTTP response.
        self.handler = handler


class THttpServer(TServer.TServer):
    """A simple HTTP-based Thrift server

    This class is not very performant, but it is useful (for example) for
    acting as a mock version of an Apache-based PHP Thrift endpoint.
    Also important to note the HTTP implementation pretty much violates the
    transport/protocol/processor/server layering, by performing the transport
    functions here.  This means things like oneway handling are oddly exposed.
    """
    def __init__(self,
                 processor,
                 server_address,
                 inputProtocolFactory,
                 outputProtocolFactory=None,
                 server_class=BaseHTTPServer.HTTPServer,
                 **kwargs):
        """Set up protocol factories and HTTP (or HTTPS) server.

        See BaseHTTPServer for server_address.
        See TServer for protocol factories.

        To make a secure server, provide the named arguments:
        * cafile    - to validate clients [optional]
        * cert_file - the server cert
        * key_file  - the server's key
        """
        if outputProtocolFactory is None:
            outputProtocolFactory = inputProtocolFactory

        TServer.TServer.__init__(self, processor, None, None, None,
                                 inputProtocolFactory, outputProtocolFactory)

        # The nested handler class reaches the server through this closure
        # variable, because inside its methods ``self`` is the handler.
        thttpserver = self
        # Set to False at the start of every POST and to True once a ONEWAY
        # request has already been answered from on_begin().
        self._replied = None

        class RequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
            def do_POST(self):
                """Process one Thrift call carried in the POST body."""
                # Don't care about the request path.
                thttpserver._replied = False

                # The body length doubles as the read-buffer size below, so
                # a missing or malformed header cannot be processed.
                raw_length = self.headers.get('Content-Length')
                if raw_length is None:
                    self.send_error(411, "Content-Length header is required")
                    return
                try:
                    content_length = int(raw_length)
                except ValueError:
                    content_length = -1
                if content_length < 0:
                    self.send_error(400, "Invalid Content-Length header")
                    return

                iftrans = TTransport.TFileObjectTransport(self.rfile)
                itrans = TTransport.TBufferedTransport(iftrans, content_length)
                otrans = TTransport.TMemoryBuffer()
                iprot = thttpserver.inputProtocolFactory.getProtocol(itrans)
                oprot = thttpserver.outputProtocolFactory.getProtocol(otrans)
                try:
                    thttpserver.processor.on_message_begin(self.on_begin)
                    thttpserver.processor.process(iprot, oprot)
                except ResponseException as exn:
                    # The service handler asked to write the HTTP response
                    # itself.
                    exn.handler(self)
                else:
                    if not thttpserver._replied:
                        # If the request was ONEWAY we would have replied already
                        data = otrans.getvalue()
                        self.send_response(200)
                        self.send_header("Content-Length", str(len(data)))
                        self.send_header("Content-Type", "application/x-thrift")
                        self.end_headers()
                        self.wfile.write(data)

            def on_begin(self, name, type, seqid):
                """
                Inspect the message header.

                This allows us to post an immediate transport response
                if the request is a ONEWAY message type.
                """
                if type == TMessageType.ONEWAY:
                    self.send_response(200)
                    self.send_header("Content-Type", "application/x-thrift")
                    self.end_headers()
                    thttpserver._replied = True

        self.httpd = server_class(server_address, RequestHandler)

        cafile = kwargs.get('cafile')
        cert_file = kwargs.get('cert_file')
        key_file = kwargs.get('key_file')
        if cafile or cert_file or key_file:
            # Purpose.CLIENT_AUTH yields a context meant for server-side
            # sockets.  The default purpose (SERVER_AUTH) configures a
            # client-side context, which recent Python versions refuse to
            # use with wrap_socket(..., server_side=True).
            context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH,
                                                 cafile=cafile)
            context.check_hostname = False
            context.load_cert_chain(cert_file, key_file)
            # Client certificates are demanded only when a CA file was
            # supplied to validate them against.
            context.verify_mode = ssl.CERT_REQUIRED if cafile else ssl.CERT_NONE
            self.httpd.socket = context.wrap_socket(self.httpd.socket,
                                                    server_side=True)

    def serve(self):
        """Handle requests until the server is shut down."""
        self.httpd.serve_forever()

    def shutdown(self):
        """Stop the server by closing its listening socket."""
        self.httpd.socket.close()
        # self.httpd.shutdown() # hangs forever, python doesn't handle POLLNVAL properly!