1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10#   http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19
20import ssl
21
22from six.moves import BaseHTTPServer
23
24from thrift.Thrift import TMessageType
25from thrift.server import TServer
26from thrift.transport import TTransport
27
28
class ResponseException(Exception):
    """Exception a Thrift handler raises to take over the HTTP response.

    By default THttpServer answers every POST with a 200 status carrying the
    serialized Thrift reply.  A handler that wants something else (e.g. to
    imitate a misconfigured or overloaded web server in a test) raises this
    exception instead.  The callable given to the constructor is then invoked
    with the active request handler (a BaseHTTPRequestHandler) as its single
    argument, and is responsible for writing the whole response.

    This does nothing useful for ONEWAY calls: their HTTP response has already
    been sent before the RPC itself is processed.
    """

    def __init__(self, handler):
        # Callable taking the request handler; it writes the HTTP response.
        self.handler = handler
42
43
class THttpServer(TServer.TServer):
    """A simple HTTP-based Thrift server

    This class is not very performant, but it is useful (for example) for
    acting as a mock version of an Apache-based PHP Thrift endpoint.
    Also important to note the HTTP implementation pretty much violates the
    transport/protocol/processor/server layering, by performing the transport
    functions here.  This means things like oneway handling are oddly exposed.
    """
    def __init__(self,
                 processor,
                 server_address,
                 inputProtocolFactory,
                 outputProtocolFactory=None,
                 server_class=BaseHTTPServer.HTTPServer,
                 **kwargs):
        """Set up protocol factories and the HTTP (or HTTPS) server.

        :param processor: Thrift processor that handles each decoded call.
        :param server_address: address passed unchanged to ``server_class``
            (see BaseHTTPServer).
        :param inputProtocolFactory: factory for the protocol that decodes
            request bodies (see TServer).
        :param outputProtocolFactory: factory for the protocol that encodes
            replies; defaults to ``inputProtocolFactory``.
        :param server_class: HTTP server class, instantiated as
            ``server_class(server_address, request_handler_class)``.

        To make a secure (HTTPS) server, provide the named arguments:
        * cafile    - CA bundle used to validate clients [optional; when
                      given, clients are required to present a certificate]
        * cert_file - the server cert
        * key_file  - the server's key
        """
        if outputProtocolFactory is None:
            outputProtocolFactory = inputProtocolFactory

        TServer.TServer.__init__(self, processor, None, None, None,
                                 inputProtocolFactory, outputProtocolFactory)

        thttpserver = self
        # Set to True by on_begin once a ONEWAY request has already been
        # answered, so do_POST does not write a second response.
        # NOTE(review): this flag lives on the server and the message-begin
        # callback is installed on the shared processor for every request, so
        # this class is presumably not safe with a threading server_class.
        self._replied = None

        class RequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
            def do_POST(self):
                """Decode one Thrift call from the POST body and reply."""
                # Don't care about the request path.
                thttpserver._replied = False
                length = self._read_content_length()
                if length is None:
                    return  # an error response has already been sent
                iftrans = TTransport.TFileObjectTransport(self.rfile)
                itrans = TTransport.TBufferedTransport(iftrans, length)
                otrans = TTransport.TMemoryBuffer()
                iprot = thttpserver.inputProtocolFactory.getProtocol(itrans)
                oprot = thttpserver.outputProtocolFactory.getProtocol(otrans)
                try:
                    thttpserver.processor.on_message_begin(self.on_begin)
                    thttpserver.processor.process(iprot, oprot)
                except ResponseException as exn:
                    # The handler takes full control of the HTTP response.
                    exn.handler(self)
                else:
                    if not thttpserver._replied:
                        # If the request was ONEWAY we would have replied already
                        data = otrans.getvalue()
                        self.send_response(200)
                        self.send_header("Content-Length", str(len(data)))
                        self.send_header("Content-Type", "application/x-thrift")
                        self.end_headers()
                        self.wfile.write(data)

            def _read_content_length(self):
                """Return the request body length in bytes.

                The body is read through a buffered transport sized from this
                header, so a usable Content-Length is required.  If it is
                missing, 411 Length Required is sent; if it is not a
                non-negative integer, 400 Bad Request is sent.  In both
                cases None is returned.
                """
                raw = self.headers.get('Content-Length')
                if raw is None:
                    self.send_error(411)
                    return None
                try:
                    length = int(raw)
                except ValueError:
                    length = -1
                if length < 0:
                    self.send_error(400, "Invalid Content-Length")
                    return None
                return length

            def on_begin(self, name, type, seqid):
                """
                Inspect the message header.

                This allows us to post an immediate transport response
                if the request is a ONEWAY message type.
                """
                if type == TMessageType.ONEWAY:
                    self.send_response(200)
                    self.send_header("Content-Type", "application/x-thrift")
                    self.end_headers()
                    thttpserver._replied = True

        self.httpd = server_class(server_address, RequestHandler)

        cafile = kwargs.get('cafile')
        cert_file = kwargs.get('cert_file')
        key_file = kwargs.get('key_file')
        if cafile or cert_file or key_file:
            # This context wraps the listening (server-side) socket, so it
            # must be built for CLIENT_AUTH.  The default SERVER_AUTH purpose
            # creates a client-only context (PROTOCOL_TLS_CLIENT on Python
            # 3.10+) that cannot be used with server_side=True.
            context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH,
                                                 cafile=cafile)
            # Hostname checking only applies to client connections.
            context.check_hostname = False
            context.load_cert_chain(cert_file, key_file)
            # Demand client certificates only when we have a CA to check them.
            context.verify_mode = ssl.CERT_REQUIRED if cafile else ssl.CERT_NONE
            self.httpd.socket = context.wrap_socket(self.httpd.socket,
                                                    server_side=True)

    def serve(self):
        """Run the underlying HTTP server's request loop (blocks)."""
        self.httpd.serve_forever()

    def shutdown(self):
        """Close the listening socket of the underlying HTTP server."""
        self.httpd.socket.close()
        # self.httpd.shutdown() is not used: it hangs forever, because
        # Python doesn't handle POLLNVAL properly.
132