1import os
2import sys
3import ssl
4import pprint
5import threading
6import urllib.parse
7# Rename HTTPServer to _HTTPServer so as to avoid confusion with HTTPSServer.
8from http.server import (HTTPServer as _HTTPServer,
9    SimpleHTTPRequestHandler, BaseHTTPRequestHandler)
10
11from test import support
12from test.support import socket_helper
13
# Directory holding this module.  The certificate below lives here, and
# RootedHTTPRequestHandler serves files from here unless told otherwise.
here = os.path.dirname(__file__)

HOST = socket_helper.HOST  # loopback address used by the test servers
CERTFILE = os.path.join(here, 'keycert.pem')  # private key + certificate
18
# HTTPSServer is built on http.server.HTTPServer, which in turn is built on socketserver.
20
class HTTPSServer(_HTTPServer):
    """HTTPServer whose accepted connections are wrapped in TLS.

    *context* must provide ``wrap_socket(sock, server_side=True)``
    (normally an ``ssl.SSLContext``); it is applied to every socket
    returned by ``accept()``.
    """

    def __init__(self, server_address, handler_class, context):
        _HTTPServer.__init__(self, server_address, handler_class)
        self.context = context

    def __str__(self):
        return f"<{type(self).__name__} {self.server_name}:{self.server_port}>"

    def get_request(self):
        """Accept a connection and return ``(tls_socket, client_address)``."""
        try:
            raw_sock, client_addr = self.socket.accept()
            tls_sock = self.context.wrap_socket(raw_sock, server_side=True)
        except OSError as err:
            # socketserver silently drops OSErrors raised from get_request(),
            # so report them here (when verbose) before re-raising.
            if support.verbose:
                sys.stderr.write(f"Got an error:\n{err}\n")
            raise
        return tls_sock, client_addr
44
class RootedHTTPRequestHandler(SimpleHTTPRequestHandler):
    """SimpleHTTPRequestHandler that serves files from a fixed ``root``.

    translate_path() is overridden to resolve URLs against ``root``
    instead of os.curdir, since the test could be run from anywhere.
    """

    server_version = "TestHTTPS/1.0"
    # Directory that request paths are resolved against.
    root = here
    # Avoid hanging when a request gets interrupted by the client
    timeout = support.LOOPBACK_TIMEOUT

    def translate_path(self, path):
        """Translate a /-separated URL *path* to a filename under ``root``.

        The query string and fragment are discarded and %-escapes are
        decoded.  Components that mean special things to the local file
        system (drive letters, OS path separators, ``.`` and ``..``) are
        ignored, so ``..`` segments cannot climb above ``root``.
        """
        import posixpath

        # Abandon query parameters and fragment.  Plain splitting (rather
        # than urlparse) keeps a leading "//seg" from being parsed as a
        # network location and silently dropped.
        path = path.split('?', 1)[0]
        path = path.split('#', 1)[0]
        # URL paths always use '/', so normalize with POSIX rules; the
        # host os.path.normpath would turn '/' into '\\' on Windows and
        # defeat the split below.
        path = posixpath.normpath(urllib.parse.unquote(path))
        result = self.root
        for word in filter(None, path.split('/')):
            drive, word = os.path.splitdrive(word)
            head, word = os.path.split(word)
            if word in (os.curdir, os.pardir):
                # normpath() keeps leading '..' on relative paths; never
                # let such a component walk out of root.
                continue
            result = os.path.join(result, word)
        return result

    def log_message(self, format, *args):
        # we override this to suppress logging unless "verbose"
        if support.verbose:
            sys.stdout.write(" server (%s:%d %s):\n   [%s] %s\n" %
                             (self.server.server_address,
                              self.server.server_port,
                              self.request.cipher(),
                              self.log_date_time_string(),
                              format%args))
84
85
class StatsRequestHandler(BaseHTTPRequestHandler):
    """Example HTTP request handler which returns SSL statistics on GET
    requests.

    The body is a pprint-formatted dict holding the server context's
    session-cache statistics and the connection's cipher and compression.
    """

    server_version = "StatsHTTPS/1.0"

    def do_GET(self, send_body=True):
        """Serve a GET request; only headers are sent if *send_body* is false."""
        # self.connection is the request socket handed over by the server
        # (the SSL socket returned by HTTPSServer.get_request()).  Using it
        # avoids the private SocketIO._sock attribute of self.rfile, which
        # also only exists while rfile is buffered.
        sock = self.connection
        stats = {
            'session_cache': sock.context.session_stats(),
            'cipher': sock.cipher(),
            'compression': sock.compression(),
            }
        body = pprint.pformat(stats).encode('utf-8')
        self.send_response(200)
        self.send_header("Content-type", "text/plain; charset=utf-8")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        if send_body:
            self.wfile.write(body)

    def do_HEAD(self):
        """Serve a HEAD request."""
        self.do_GET(send_body=False)

    def log_request(self, code='-', size='-'):
        # Same signature as BaseHTTPRequestHandler.log_request(); only
        # log accepted requests when running verbosely.
        if support.verbose:
            BaseHTTPRequestHandler.log_request(self, code, size)
118
119
class HTTPSServerThread(threading.Thread):
    """Daemon thread that runs an HTTPSServer on an ephemeral port.

    The server (and its listening socket) is created in the constructor,
    so ``port`` is available before the thread is started.
    """

    def __init__(self, context, host=HOST, handler_class=None):
        self.flag = None
        handler = handler_class or RootedHTTPRequestHandler
        self.server = HTTPSServer((host, 0), handler, context)
        self.port = self.server.server_port
        threading.Thread.__init__(self, daemon=True)

    def __str__(self):
        return f"<{type(self).__name__} {self.server}>"

    def start(self, flag=None):
        """Start the thread; *flag* (if any) is set just before serving."""
        self.flag = flag
        threading.Thread.start(self)

    def run(self):
        started = self.flag
        if started:
            started.set()
        try:
            self.server.serve_forever(poll_interval=0.05)
        finally:
            # Always release the listening socket, even if serving failed.
            self.server.server_close()

    def stop(self):
        """Ask serve_forever() to stop and block until it has returned."""
        self.server.shutdown()
148
149
def make_https_server(case, *, context=None, certfile=CERTFILE,
                      host=HOST, handler_class=None):
    """Start an HTTPSServerThread for test *case* and return it.

    When *context* is None a default server-side context is created.  In
    either case *certfile* (assumed to hold both the private key and the
    certificate) is loaded into the context.  Stopping and joining the
    thread is registered with ``case.addCleanup``.
    """
    if context is None:
        context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
    context.load_cert_chain(certfile)

    server = HTTPSServerThread(context, host, handler_class)
    ready = threading.Event()
    server.start(ready)
    ready.wait()

    def say(message):
        if support.verbose:
            sys.stdout.write(message)

    def cleanup():
        say('stopping HTTPS server\n')
        server.stop()
        say('joining HTTPS thread\n')
        server.join()

    case.addCleanup(cleanup)
    return server
169
170
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description='Run a test HTTPS server. '
                    'By default, the current directory is served.')
    parser.add_argument('-p', '--port', type=int, default=4433,
                        help='port to listen on (default: %(default)s)')
    parser.add_argument('-q', '--quiet', dest='verbose', default=True,
                        action='store_false', help='be less verbose')
    parser.add_argument('-s', '--stats', dest='use_stats_handler',
                        default=False, action='store_true',
                        help='always return stats page')
    parser.add_argument('--curve-name', dest='curve_name', type=str,
                        help='curve name for EC-based Diffie-Hellman')
    parser.add_argument('--ciphers', dest='ciphers', type=str,
                        help='allowed cipher list')
    parser.add_argument('--dh', dest='dh_file', type=str,
                        help='PEM file containing DH parameters')
    args = parser.parse_args()

    support.verbose = args.verbose

    # Either report TLS stats for every request, or serve files from the
    # current working directory.
    if not args.use_stats_handler:
        RootedHTTPRequestHandler.root = os.getcwd()
    handler_class = (StatsRequestHandler if args.use_stats_handler
                     else RootedHTTPRequestHandler)

    context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
    context.load_cert_chain(CERTFILE)
    # Optional context settings, applied in this order only when given.
    for option_value, apply_option in ((args.curve_name, context.set_ecdh_curve),
                                       (args.dh_file, context.load_dh_params),
                                       (args.ciphers, context.set_ciphers)):
        if option_value:
            apply_option(option_value)

    server = HTTPSServer(("", args.port), handler_class, context)
    if args.verbose:
        print(f"Listening on https://localhost:{args.port}")
    server.serve_forever(0.1)
210