1import contextlib
2import http.server
3import pathlib
4import ssl
5import threading
6
# Self-signed certificate *and* its private key, stored together in a single
# PEM file (note -out and -keyout name the same file), with a
# subjectAltName of DNS:localhost. Regenerate with:
# $ openssl req -new -x509 -days 3650 -nodes -out cert.pem \
#     -keyout cert.pem -addext "subjectAltName = DNS:localhost"
CERT_FILE = str(pathlib.Path(__file__).parent.joinpath("certs", "cert.pem"))
11
12
class HttpServerThread(threading.Thread):
    """Thread that runs an ``http.server.HTTPServer`` in the background.

    The listening socket is created and bound in ``__init__``, so the port
    is already reserved before ``start()`` is called. ``run`` serves
    requests until ``terminate`` is called.
    """

    def __init__(
        self, handler, *args, address=("localhost", 7777), **kwargs
    ):
        """Create the server and the (not yet started) thread.

        Args:
            handler: Request handler class passed to ``HTTPServer``.
            *args: Forwarded to ``threading.Thread``.
            address: ``(host, port)`` to bind; defaults to
                ``("localhost", 7777)``. Pass port ``0`` to let the OS pick
                a free port and read it back from
                ``self.server.server_address``.
            **kwargs: Forwarded to ``threading.Thread`` (e.g. ``daemon``).
        """
        super().__init__(*args, **kwargs)
        self.server = http.server.HTTPServer(address, handler)

    def run(self):
        """Serve requests until ``terminate`` stops the loop."""
        # A short poll interval keeps shutdown() latency low.
        self.server.serve_forever(poll_interval=0.01)

    def terminate(self):
        """Stop serving, close the listening socket and join the thread.

        Safe to call on a thread that was never started or has already
        stopped, and safe to call more than once.
        """
        # HTTPServer.shutdown() blocks until serve_forever() exits; calling
        # it when the serving thread is not running would deadlock.
        if self.is_alive():
            self.server.shutdown()
        self.server.server_close()
        # Thread.join() raises RuntimeError on a thread never started.
        if self.ident is not None:
            self.join()
25
26
class HttpsServerThread(HttpServerThread):
    """``HttpServerThread`` whose listening socket is wrapped in TLS."""

    def __init__(self, handler, *args, certfile=None, **kwargs):
        """Create the HTTP server, then wrap its socket for server-side TLS.

        Args:
            handler: Request handler class, as for ``HttpServerThread``.
            *args: Forwarded to ``HttpServerThread``.
            certfile: PEM file containing both the certificate and its
                private key. ``None`` (the default) means ``CERT_FILE``,
                looked up at call time.
            **kwargs: Forwarded to ``HttpServerThread``.

        Raises:
            OSError, ssl.SSLError: if the certificate cannot be loaded or the
                socket cannot be wrapped. The already-bound listening socket
                is closed before the error propagates.
        """
        super().__init__(handler, *args, **kwargs)
        if certfile is None:
            certfile = CERT_FILE
        try:
            # ssl.wrap_socket() was deprecated in 3.7 and removed in 3.12;
            # an explicit server-side SSLContext is the supported API.
            context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
            context.load_cert_chain(certfile)
            self.server.socket = context.wrap_socket(
                self.server.socket,
                server_side=True,
            )
        except BaseException:
            # Don't leak the port bound by HTTPServer in the base __init__.
            self.server.server_close()
            raise
35
36
def create_server(thread_class):
    """Return a context-manager factory that runs servers of *thread_class*.

    The returned callable takes a request-handler class. Entering it builds
    ``thread_class(handler, daemon=True)``, starts that thread and yields
    it. Leaving the ``with`` block, whether normally or through an
    exception, calls the thread's ``terminate()``.
    """

    @contextlib.contextmanager
    def server(handler):
        running = thread_class(handler, daemon=True)
        running.start()
        try:
            yield running
        finally:
            # Always tear the server down, even if the with-body raised.
            running.terminate()

    return server
46
47
# Context-manager factories, used as ``with http_server(Handler) as thread:``.
http_server = create_server(thread_class=HttpServerThread)
https_server = create_server(thread_class=HttpsServerThread)
50