1from __future__ import absolute_import, division, print_function, unicode_literals 2from future.builtins import filter, str 3from future import utils 4import os 5import sys 6import ssl 7import pprint 8import socket 9from future.backports.urllib import parse as urllib_parse 10from future.backports.http.server import (HTTPServer as _HTTPServer, 11 SimpleHTTPRequestHandler, BaseHTTPRequestHandler) 12from future.backports.test import support 13threading = support.import_module("threading") 14 15here = os.path.dirname(__file__) 16 17HOST = support.HOST 18CERTFILE = os.path.join(here, 'keycert.pem') 19 20# This one's based on HTTPServer, which is based on SocketServer 21 22class HTTPSServer(_HTTPServer): 23 24 def __init__(self, server_address, handler_class, context): 25 _HTTPServer.__init__(self, server_address, handler_class) 26 self.context = context 27 28 def __str__(self): 29 return ('<%s %s:%s>' % 30 (self.__class__.__name__, 31 self.server_name, 32 self.server_port)) 33 34 def get_request(self): 35 # override this to wrap socket with SSL 36 try: 37 sock, addr = self.socket.accept() 38 sslconn = self.context.wrap_socket(sock, server_side=True) 39 except socket.error as e: 40 # socket errors are silenced by the caller, print them here 41 if support.verbose: 42 sys.stderr.write("Got an error:\n%s\n" % e) 43 raise 44 return sslconn, addr 45 46class RootedHTTPRequestHandler(SimpleHTTPRequestHandler): 47 # need to override translate_path to get a known root, 48 # instead of using os.curdir, since the test could be 49 # run from anywhere 50 51 server_version = "TestHTTPS/1.0" 52 root = here 53 # Avoid hanging when a request gets interrupted by the client 54 timeout = 5 55 56 def translate_path(self, path): 57 """Translate a /-separated PATH to the local filename syntax. 58 59 Components that mean special things to the local file system 60 (e.g. drive or directory names) are ignored. (XXX They should 61 probably be diagnosed.) 62 63 """ 64 # abandon query parameters 65 path = urllib.parse.urlparse(path)[2] 66 path = os.path.normpath(urllib.parse.unquote(path)) 67 words = path.split('/') 68 words = filter(None, words) 69 path = self.root 70 for word in words: 71 drive, word = os.path.splitdrive(word) 72 head, word = os.path.split(word) 73 path = os.path.join(path, word) 74 return path 75 76 def log_message(self, format, *args): 77 # we override this to suppress logging unless "verbose" 78 if support.verbose: 79 sys.stdout.write(" server (%s:%d %s):\n [%s] %s\n" % 80 (self.server.server_address, 81 self.server.server_port, 82 self.request.cipher(), 83 self.log_date_time_string(), 84 format%args)) 85 86 87class StatsRequestHandler(BaseHTTPRequestHandler): 88 """Example HTTP request handler which returns SSL statistics on GET 89 requests. 90 """ 91 92 server_version = "StatsHTTPS/1.0" 93 94 def do_GET(self, send_body=True): 95 """Serve a GET request.""" 96 sock = self.rfile.raw._sock 97 context = sock.context 98 stats = { 99 'session_cache': context.session_stats(), 100 'cipher': sock.cipher(), 101 'compression': sock.compression(), 102 } 103 body = pprint.pformat(stats) 104 body = body.encode('utf-8') 105 self.send_response(200) 106 self.send_header("Content-type", "text/plain; charset=utf-8") 107 self.send_header("Content-Length", str(len(body))) 108 self.end_headers() 109 if send_body: 110 self.wfile.write(body) 111 112 def do_HEAD(self): 113 """Serve a HEAD request.""" 114 self.do_GET(send_body=False) 115 116 def log_request(self, format, *args): 117 if support.verbose: 118 BaseHTTPRequestHandler.log_request(self, format, *args) 119 120 121class HTTPSServerThread(threading.Thread): 122 123 def __init__(self, context, host=HOST, handler_class=None): 124 self.flag = None 125 self.server = HTTPSServer((host, 0), 126 handler_class or RootedHTTPRequestHandler, 127 context) 128 self.port = self.server.server_port 129 threading.Thread.__init__(self) 130 self.daemon = True 131 132 def __str__(self): 133 return "<%s %s>" % (self.__class__.__name__, self.server) 134 135 def start(self, flag=None): 136 self.flag = flag 137 threading.Thread.start(self) 138 139 def run(self): 140 if self.flag: 141 self.flag.set() 142 try: 143 self.server.serve_forever(0.05) 144 finally: 145 self.server.server_close() 146 147 def stop(self): 148 self.server.shutdown() 149 150 151def make_https_server(case, certfile=CERTFILE, host=HOST, handler_class=None): 152 # we assume the certfile contains both private key and certificate 153 context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) 154 context.load_cert_chain(certfile) 155 server = HTTPSServerThread(context, host, handler_class) 156 flag = threading.Event() 157 server.start(flag) 158 flag.wait() 159 def cleanup(): 160 if support.verbose: 161 sys.stdout.write('stopping HTTPS server\n') 162 server.stop() 163 if support.verbose: 164 sys.stdout.write('joining HTTPS thread\n') 165 server.join() 166 case.addCleanup(cleanup) 167 return server 168 169 170if __name__ == "__main__": 171 import argparse 172 parser = argparse.ArgumentParser( 173 description='Run a test HTTPS server. ' 174 'By default, the current directory is served.') 175 parser.add_argument('-p', '--port', type=int, default=4433, 176 help='port to listen on (default: %(default)s)') 177 parser.add_argument('-q', '--quiet', dest='verbose', default=True, 178 action='store_false', help='be less verbose') 179 parser.add_argument('-s', '--stats', dest='use_stats_handler', default=False, 180 action='store_true', help='always return stats page') 181 parser.add_argument('--curve-name', dest='curve_name', type=str, 182 action='store', 183 help='curve name for EC-based Diffie-Hellman') 184 parser.add_argument('--dh', dest='dh_file', type=str, action='store', 185 help='PEM file containing DH parameters') 186 args = parser.parse_args() 187 188 support.verbose = args.verbose 189 if args.use_stats_handler: 190 handler_class = StatsRequestHandler 191 else: 192 handler_class = RootedHTTPRequestHandler 193 if utils.PY2: 194 handler_class.root = os.getcwdu() 195 else: 196 handler_class.root = os.getcwd() 197 context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) 198 context.load_cert_chain(CERTFILE) 199 if args.curve_name: 200 context.set_ecdh_curve(args.curve_name) 201 if args.dh_file: 202 context.load_dh_params(args.dh_file) 203 204 server = HTTPSServer(("", args.port), handler_class, context) 205 if args.verbose: 206 print("Listening on https://localhost:{0.port}".format(args)) 207 server.serve_forever(0.1) 208