from __future__ import absolute_import, division, print_function, unicode_literals
from future.builtins import filter, str
from future import utils
import os
import posixpath
import sys
import ssl
import pprint
import socket
from future.backports.urllib import parse as urllib_parse
from future.backports.http.server import (HTTPServer as _HTTPServer,
    SimpleHTTPRequestHandler, BaseHTTPRequestHandler)
from future.backports.test import support
# ``threading`` is obtained through the test-support helper rather than a
# plain import (presumably so tests are skipped when threads are unavailable
# -- confirm against future.backports.test.support.import_module).
threading = support.import_module("threading")

# Directory containing this module: the default root for served files and
# the location of the test certificate.
here = os.path.dirname(__file__)

# Default bind address for the test servers, taken from the support module.
HOST = support.HOST
# PEM file used by make_https_server() and the __main__ script; it is
# expected to contain both the private key and the certificate.
CERTFILE = os.path.join(here, 'keycert.pem')
19
# HTTPS server built on HTTPServer (which is itself based on SocketServer):
# connections are accepted as plain TCP and then wrapped with TLS.

class HTTPSServer(_HTTPServer):
    """HTTPServer whose accepted connections are wrapped by an SSLContext."""

    def __init__(self, server_address, handler_class, context):
        _HTTPServer.__init__(self, server_address, handler_class)
        # SSLContext used to wrap every accepted connection server-side.
        self.context = context

    def __str__(self):
        template = '<%s %s:%s>'
        return template % (self.__class__.__name__,
                           self.server_name, self.server_port)

    def get_request(self):
        """Accept one connection and wrap it for the server side of TLS.

        Returns an ``(ssl_socket, client_address)`` pair.  Socket errors are
        echoed to stderr in verbose mode and then re-raised.
        """
        try:
            raw_sock, client_addr = self.socket.accept()
            tls_sock = self.context.wrap_socket(raw_sock, server_side=True)
        except socket.error as err:
            # The caller silences socket errors, so report them here.
            if support.verbose:
                sys.stderr.write("Got an error:\n%s\n" % err)
            raise
        else:
            return tls_sock, client_addr
45
class RootedHTTPRequestHandler(SimpleHTTPRequestHandler):
    """SimpleHTTPRequestHandler that serves files from a fixed ``root``.

    translate_path is overridden to use a known root instead of os.curdir,
    since the test could be run from anywhere.
    """

    server_version = "TestHTTPS/1.0"
    # Directory files are served from: this module's directory by default
    # (the __main__ script replaces it with the current directory).
    root = here
    # Avoid hanging when a request gets interrupted by the client
    timeout = 5

    def translate_path(self, path):
        """Translate a /-separated URL path to a filename under ``self.root``.

        The query string and fragment are discarded, and the remainder is
        percent-decoded and normalised.  Components that mean special things
        to the local file system (drive letters, embedded directory parts,
        ``.`` and ``..``) are dropped, so the result stays under ``self.root``.
        """
        # abandon query parameters (and fragment)
        path = urllib_parse.urlparse(path)[2]
        # URL paths are always '/'-separated, so normalise them with
        # posixpath; os.path would switch to '\\' on Windows and defeat the
        # split below.
        path = posixpath.normpath(urllib_parse.unquote(path))
        words = filter(None, path.split('/'))
        path = self.root
        for word in words:
            drive, word = os.path.splitdrive(word)
            head, word = os.path.split(word)
            # A relative request path (e.g. "../../etc/passwd") keeps its
            # leading '..' after normpath; skip such components so the
            # request cannot escape the served root.
            if word in (os.curdir, os.pardir):
                continue
            path = os.path.join(path, word)
        return path

    def log_message(self, format, *args):
        """Write a request log line to stdout, only when ``support.verbose``.

        The line is prefixed with the server's name and port and the cipher
        negotiated on this connection.
        """
        # we override this to suppress logging unless "verbose"
        if support.verbose:
            # server_address is a (host, port) tuple; formatting it with %s
            # would print the port twice, so use server_name instead.
            sys.stdout.write(" server (%s:%d %s):\n   [%s] %s\n" %
                             (self.server.server_name,
                              self.server.server_port,
                              self.request.cipher(),
                              self.log_date_time_string(),
                              format % args))
85
86
class StatsRequestHandler(BaseHTTPRequestHandler):
    """Example HTTP request handler which returns SSL statistics on GET
    requests.

    The response body is a ``pprint``-formatted dict holding the SSL
    context's session-cache statistics and the cipher and compression of the
    current connection, sent as UTF-8 plain text.
    """

    server_version = "StatsHTTPS/1.0"

    def do_GET(self, send_body=True):
        """Serve a GET request.

        :param send_body: when false, only the status line and headers are
            written (this is how :meth:`do_HEAD` reuses this method).
        """
        # NOTE(review): reaches the SSL socket through the private ``_sock``
        # attribute of the raw socket file -- confirm it exists for the
        # socket/file implementation in use.
        sock = self.rfile.raw._sock
        context = sock.context
        stats = {
            'session_cache': context.session_stats(),
            'cipher': sock.cipher(),
            'compression': sock.compression(),
            }
        body = pprint.pformat(stats).encode('utf-8')
        self.send_response(200)
        self.send_header("Content-type", "text/plain; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        if send_body:
            self.wfile.write(body)

    def do_HEAD(self):
        """Serve a HEAD request: the GET headers without the body."""
        self.do_GET(send_body=False)

    def log_request(self, format='-', *args):
        """Log an accepted request, but only when ``support.verbose``.

        ``format`` defaults to ``'-'``, like the base class's
        ``log_request(code='-', size='-')``, so this override can still be
        called without arguments.
        """
        if support.verbose:
            BaseHTTPRequestHandler.log_request(self, format, *args)
119
120
class HTTPSServerThread(threading.Thread):
    """Daemon thread that runs an :class:`HTTPSServer` in the background.

    The server is created in the constructor on *host* with port 0, and the
    resulting port is stored in :attr:`port` before the thread starts.
    """

    def __init__(self, context, host=HOST, handler_class=None):
        self.flag = None
        handler = handler_class if handler_class else RootedHTTPRequestHandler
        self.server = HTTPSServer((host, 0), handler, context)
        self.port = self.server.server_port
        # Thread state must be initialised before ``daemon`` can be set.
        threading.Thread.__init__(self)
        self.daemon = True

    def __str__(self):
        return "<%s %s>" % (self.__class__.__name__, self.server)

    def start(self, flag=None):
        """Start the thread; *flag*, if given, is set once run() begins."""
        self.flag = flag
        threading.Thread.start(self)

    def run(self):
        """Signal readiness, then serve until stop() is called."""
        started = self.flag
        if started:
            started.set()
        try:
            self.server.serve_forever(0.05)
        finally:
            # Always release the listening socket, even on error.
            self.server.server_close()

    def stop(self):
        """Ask serve_forever() to return (blocks until it does)."""
        self.server.shutdown()
149
150
def make_https_server(case, certfile=CERTFILE, host=HOST, handler_class=None):
    """Start a background HTTPS server for the unittest *case* and return it.

    The returned :class:`HTTPSServerThread` is already serving when this
    function returns, and a cleanup is registered on *case* that stops the
    server and joins its thread.
    """
    # we assume the certfile contains both private key and certificate
    ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
    ctx.load_cert_chain(certfile)
    started = threading.Event()
    server = HTTPSServerThread(ctx, host, handler_class)
    server.start(started)
    # Block until the thread has entered run().
    started.wait()

    def _note(message):
        # Progress messages are only emitted in verbose mode.
        if support.verbose:
            sys.stdout.write(message)

    def _shutdown():
        _note('stopping HTTPS server\n')
        server.stop()
        _note('joining HTTPS thread\n')
        server.join()

    case.addCleanup(_shutdown)
    return server
168
169
if __name__ == "__main__":
    # Stand-alone mode: run a test HTTPS server in the foreground.
    import argparse
    parser = argparse.ArgumentParser(
        description='Run a test HTTPS server. '
                    'By default, the current directory is served.')
    parser.add_argument('-p', '--port', type=int, default=4433,
                        help='port to listen on (default: %(default)s)')
    parser.add_argument('-q', '--quiet', dest='verbose', default=True,
                        action='store_false', help='be less verbose')
    parser.add_argument('-s', '--stats', dest='use_stats_handler', default=False,
                        action='store_true', help='always return stats page')
    parser.add_argument('--curve-name', dest='curve_name', type=str,
                        action='store',
                        help='curve name for EC-based Diffie-Hellman')
    parser.add_argument('--dh', dest='dh_file', type=str, action='store',
                        help='PEM file containing DH parameters')
    args = parser.parse_args()

    support.verbose = args.verbose
    if args.use_stats_handler:
        handler_class = StatsRequestHandler
    else:
        handler_class = RootedHTTPRequestHandler
        # Serve the current directory rather than this module's directory.
        if utils.PY2:
            handler_class.root = os.getcwdu()
        else:
            handler_class.root = os.getcwd()
    # PROTOCOL_SSLv23 negotiates the highest protocol version both peers
    # support, the same as make_https_server.  PROTOCOL_TLSv1 pinned the
    # server to TLS 1.0, which current browsers and many default OpenSSL
    # configurations reject.
    context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
    context.load_cert_chain(CERTFILE)
    if args.curve_name:
        context.set_ecdh_curve(args.curve_name)
    if args.dh_file:
        context.load_dh_params(args.dh_file)

    server = HTTPSServer(("", args.port), handler_class, context)
    if args.verbose:
        print("Listening on https://localhost:{0.port}".format(args))
    try:
        server.serve_forever(0.1)
    finally:
        # Release the listening socket however serving ends (e.g. Ctrl-C).
        server.server_close()
208