1import json
2import os
3import signal
4import socket
5import sys
6import time
7
8from mozlog import get_default_logger, handlers
9
10from . import mpcontext
11from .wptlogging import LogLevelRewriter, QueueHandler, LogQueueThread
12
# Directory containing this module.
here = os.path.dirname(__file__)
# Three directory levels above this module. It is expected to be the
# repository root: ``tools`` is imported from it below, and get_routes() reads
# ``resources/testdriver.js`` from it.
repo_root = os.path.abspath(os.path.join(here, os.pardir, os.pardir, os.pardir))

# Put the repository root first on sys.path so that ``tools`` is importable.
# ``localpaths`` is otherwise unused (hence the noqa); it is imported for its
# import-time side effects (presumably further sys.path setup — confirm).
sys.path.insert(0, repo_root)
from tools import localpaths  # noqa: F401

from wptserve.handlers import StringHandler

# The ``tools.serve`` module. It is bound by do_delayed_imports() once the
# tests root is known, and is None until then.
serve = None
22
23
def do_delayed_imports(logger, test_paths):
    """Import ``tools.serve`` from the checkout that contains the tests.

    The tests root for the URL base "/" is pushed onto the front of
    ``sys.path`` so that ``tools.serve`` resolves against that checkout. On
    success, the module is bound to the module-level ``serve`` global.

    :param logger: logger used to report a failed import.
    :param test_paths: mapping of URL base to path configuration; must have
        a "/" entry.
    :raises SystemExit: if ``tools.serve`` cannot be imported.
    """
    global serve

    root = serve_path(test_paths)
    sys.path.insert(0, root)

    missing = []
    try:
        from tools.serve import serve
    except ImportError:
        missing.append("serve")

    if not missing:
        return

    logger.critical(
        "Failed to import %s. Ensure that tests path %s contains web-platform-tests" %
        (", ".join(missing), root))
    sys.exit(1)
42
43
def serve_path(test_paths):
    """Return the filesystem tests path configured for the URL root "/".

    :param test_paths: mapping of URL base to a dict with a "tests_path" key.
    :raises KeyError: if there is no "/" entry or it lacks "tests_path".
    """
    root_entry = test_paths["/"]
    return root_entry["tests_path"]
46
47
class TestEnvironmentError(Exception):
    """Exception type for test-environment failures.

    NOTE(review): TestEnvironment.ensure_started raises the builtin
    ``EnvironmentError`` rather than this class.
    """
50
51
def get_server_logger():
    """Return the mozlog logger for the ``wptserve`` component.

    The logger gets a component filter built from mozlog's LogLevelFilter
    (level "info") wrapped in LogLevelRewriter. The rewriter turns "error"
    level messages from the server into "warning" level.
    """
    server_logger = get_default_logger(component="wptserve")
    info_filter = handlers.LogLevelFilter(lambda x: x, "info")
    # Downgrade errors to warnings for the server
    server_logger.component_filter = LogLevelRewriter(info_filter, ["error"], "warning")
    return server_logger
59
60
class ProxyLoggingContext:
    """Context manager owning the log queue used to forward wptserve logging.

    A multiprocessing queue is created from the shared mp context.
    ``__enter__`` starts a LogQueueThread that feeds queued records to the
    given logger, and returns a QueueHandler writing into that queue for the
    server processes to use. ``__exit__`` queues a ``None`` entry and joins
    the thread with a one-second timeout.
    """

    def __init__(self, logger):
        queue = mpcontext.get_context().Queue()
        self.log_queue = queue
        self.logging_thread = LogQueueThread(queue, logger)
        self.logger_handler = QueueHandler(queue)

    def __enter__(self):
        self.logging_thread.start()
        return self.logger_handler

    def __exit__(self, *args):
        # The None entry is presumably LogQueueThread's stop sentinel — confirm
        # against wptlogging.
        self.log_queue.put(None)
        # Wait for thread to shut down but not for too long since it's a daemon
        self.logging_thread.join(1)
79
80
class TestEnvironment:
    """Context manager that owns the test environment i.e. the http and
    websockets servers, together with the stash server, a shared cache
    manager, the server logging proxy and any ``env_extras`` context managers.

    :param test_paths: mapping of URL base to a dict with a "tests_path" key;
        the "/" entry (when present) supplies the document root.
    :param testharness_timeout_multipler: substituted into
        testharnessreport.js as ``timeout_multiplier``.
    :param pause_after_test: substituted into testharnessreport.js as
        ``output``.
    :param debug_test: when truthy, testharnessreport.js gets ``debug: true``.
    :param debug_info: debugger information or None; when not None,
        testharnessreport.js gets ``explicit_timeout: true``.
    :param options: dict of environment options, or None. The
        "test_server_port" key (default True) is removed and controls whether
        :meth:`test_servers` probes ports. The caller's dict is not modified.
    :param ssl_config: dict of SSL settings, copied into the server config.
    :param env_extras: iterable of callables ``f(options, config)`` that
        return context managers to enter alongside the servers.
    :param enable_quic: also configure a "quic-transport" port.
    :param mojojs_path: if set, mounted at "/gen/".
    """
    def __init__(self, test_paths, testharness_timeout_multipler,
                 pause_after_test, debug_test, debug_info, options, ssl_config, env_extras,
                 enable_quic=False, mojojs_path=None):
        # Normalise None before calling any dict method on it, and work on a
        # copy so that popping runner-only keys does not mutate the caller's
        # dict (which would also lose the setting if the dict were reused).
        options = dict(options) if options is not None else {}

        self.test_paths = test_paths
        self.server = None  # not used by this class; kept for external readers
        self.servers = {}  # scheme -> [(port, server), ...]; set in __enter__
        self.config_ctx = None
        self.config = None
        self.server_logger = get_server_logger()
        self.server_logging_ctx = ProxyLoggingContext(self.server_logger)
        self.testharness_timeout_multipler = testharness_timeout_multipler
        self.pause_after_test = pause_after_test
        self.debug_test = debug_test
        self.test_server_port = options.pop("test_server_port", True)
        self.debug_info = debug_info
        self.options = options

        mp_context = mpcontext.get_context()
        self.cache_manager = mp_context.Manager()
        self.stash = serve.stash.StashServer(mp_context=mp_context)
        self.env_extras = env_extras
        self.env_extras_cms = None
        self.ssl_config = ssl_config
        self.enable_quic = enable_quic
        self.mojojs_path = mojojs_path

    def __enter__(self):
        """Start logging, config, stash, cache manager, env extras and servers.

        If any step raises, everything already entered is exited again before
        the exception propagates. This matters because the ``with`` statement
        does not call ``__exit__`` when ``__enter__`` fails.

        :returns: ``self``.
        """
        from contextlib import ExitStack

        # Check before any side effects. Otherwise a nested ``with`` would first
        # call start() a second time on the already-running logging thread.
        assert self.env_extras_cms is None, (
            "A TestEnvironment object cannot be nested")

        self.env_extras_cms = []
        try:
            with ExitStack() as stack:
                server_log_handler = stack.enter_context(self.server_logging_ctx)
                self.config_ctx = self.build_config()
                self.config = stack.enter_context(self.config_ctx)

                stack.enter_context(self.stash)
                stack.enter_context(self.cache_manager)

                for env in self.env_extras:
                    cm = env(self.options, self.config)
                    stack.enter_context(cm)
                    self.env_extras_cms.append(cm)

                self.servers = serve.start(self.server_logger,
                                           self.config,
                                           self.get_routes(),
                                           mp_context=mpcontext.get_context(),
                                           log_handlers=[server_log_handler])
                # Setup succeeded: from here on __exit__ owns the teardown.
                stack.pop_all()
        except BaseException:
            self.env_extras_cms = None
            raise

        if self.options.get("supports_debugger") and self.debug_info and self.debug_info.interactive:
            self.ignore_interrupts()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore SIGINT handling, stop the servers and tear down helpers.

        Env-extra context managers are exited in the order they were entered.
        They are followed by the cache manager, stash, config and logging
        proxy.
        """
        self.process_interrupts()

        for servers in self.servers.values():
            for _port, server in servers:
                server.stop()
        for cm in self.env_extras_cms:
            cm.__exit__(exc_type, exc_val, exc_tb)

        self.env_extras_cms = None

        self.cache_manager.__exit__(exc_type, exc_val, exc_tb)
        self.stash.__exit__()
        self.config_ctx.__exit__(exc_type, exc_val, exc_tb)
        self.server_logging_ctx.__exit__(exc_type, exc_val, exc_tb)

    def ignore_interrupts(self):
        """Ignore SIGINT, e.g. while an interactive debugger is attached."""
        signal.signal(signal.SIGINT, signal.SIG_IGN)

    def process_interrupts(self):
        """Restore the default SIGINT handler."""
        signal.signal(signal.SIGINT, signal.SIG_DFL)

    def build_config(self):
        """Return a ``serve.ConfigBuilder`` describing the server setup.

        Default ports are set first, so a ``config.json`` in the tests root can
        override them along with anything else. Subdomain checks are then
        disabled, and SSL settings come from ``ssl_config`` plus the
        "encrypt_after_connect" option. "browser_host", "bind_address" and
        "server_host" are taken from the options when given. The document root
        is the tests root. The builder is used as a context manager by
        ``__enter__``.
        """
        override_path = os.path.join(serve_path(self.test_paths), "config.json")

        config = serve.ConfigBuilder(self.server_logger)

        ports = {
            "http": [8000, 8001],
            "http-private": [8002],
            "http-public": [8003],
            "https": [8443, 8444],
            "https-private": [8445],
            "https-public": [8446],
            "ws": [8888],
            "wss": [8889],
            "h2": [9000],
        }
        if self.enable_quic:
            ports["quic-transport"] = [10000]
        config.ports = ports

        if os.path.exists(override_path):
            with open(override_path) as f:
                override_obj = json.load(f)
            config.update(override_obj)

        config.check_subdomains = False

        ssl_config = self.ssl_config.copy()
        ssl_config["encrypt_after_connect"] = self.options.get("encrypt_after_connect", False)
        config.ssl = ssl_config

        if "browser_host" in self.options:
            config.browser_host = self.options["browser_host"]

        if "bind_address" in self.options:
            config.bind_address = self.options["bind_address"]

        config.server_host = self.options.get("server_host", None)
        config.doc_root = serve_path(self.test_paths)

        return config

    def get_routes(self):
        """Return the wptserve routes for the test harness.

        The routes consist of:

        * static harness files, including a templated testharnessreport.js,
          served with ``Cache-Control: max-age=3600``;
        * a /resources/testdriver.js made of the repository's testdriver.js
          followed by this directory's testdriver-extra.js;
        * one mount point per non-root entry in ``test_paths``;
        * "/gen/" mapped to ``mojojs_path`` when that is set.

        The default "/" mount point is removed when ``test_paths`` has no "/".
        """
        route_builder = serve.RoutesBuilder()

        for path, format_args, content_type, route in [
                ("testharness_runner.html", {}, "text/html", "/testharness_runner.html"),
                ("print_reftest_runner.html", {}, "text/html", "/print_reftest_runner.html"),
                (os.path.join(here, "..", "..", "third_party", "pdf_js", "pdf.js"), None,
                 "text/javascript", "/_pdf_js/pdf.js"),
                (os.path.join(here, "..", "..", "third_party", "pdf_js", "pdf.worker.js"), None,
                 "text/javascript", "/_pdf_js/pdf.worker.js"),
                (self.options.get("testharnessreport", "testharnessreport.js"),
                 {"output": self.pause_after_test,
                  "timeout_multiplier": self.testharness_timeout_multipler,
                  "explicit_timeout": "true" if self.debug_info is not None else "false",
                  "debug": "true" if self.debug_test else "false"},
                 "text/javascript;charset=utf8",
                 "/resources/testharnessreport.js")]:
            # Relative paths are resolved against this directory; absolute
            # ones are left as they are by os.path.join.
            path = os.path.normpath(os.path.join(here, path))
            # Note that .headers. files don't apply to static routes, so we need to
            # readd any static headers here.
            headers = {"Cache-Control": "max-age=3600"}
            route_builder.add_static(path, format_args, content_type, route,
                                     headers=headers)

        data = b""
        with open(os.path.join(repo_root, "resources", "testdriver.js"), "rb") as fp:
            data += fp.read()
        with open(os.path.join(here, "testdriver-extra.js"), "rb") as fp:
            data += fp.read()
        route_builder.add_handler("GET", "/resources/testdriver.js",
                                  StringHandler(data, "text/javascript"))

        for url_base, paths in self.test_paths.items():
            if url_base == "/":
                continue
            route_builder.add_mount_point(url_base, paths["tests_path"])

        if "/" not in self.test_paths:
            del route_builder.mountpoint_routes["/"]

        if self.mojojs_path:
            route_builder.add_mount_point("/gen/", self.mojojs_path)

        return route_builder.get_routes()

    def ensure_started(self, total_sleep_secs=30, each_sleep_secs=0.5):
        """Poll :meth:`test_servers` until every server is up.

        :param total_sleep_secs: how long to keep polling before giving up.
        :param each_sleep_secs: delay between polls.
        :raises EnvironmentError: if a server process is not alive, or if some
            port still refuses connections when the time runs out. The message
            lists the offending ``scheme:port`` or ``host:port`` entries.
        """
        # monotonic() so that wall-clock adjustments cannot stretch or cut
        # short the wait.
        deadline = time.monotonic() + total_sleep_secs
        while True:
            failed, pending = self.test_servers()
            if failed:
                break
            if not pending:
                return
            if time.monotonic() >= deadline:
                break
            time.sleep(each_sleep_secs)
        # On timeout nothing has "failed", so report the still-pending ports
        # instead of an empty list.
        unavailable = failed or pending
        raise EnvironmentError("Servers failed to start: %s" %
                               ", ".join("%s:%s" % item for item in unavailable))

    def test_servers(self):
        """Check the state of every started server.

        :returns: ``(failed, pending)``. ``failed`` lists ``(scheme, port)`` for
            server processes that are not alive. ``pending`` lists
            ``(host, port)`` for ports that refused a TCP connection within
            0.1s. Ports are only probed when nothing has failed and
            ``test_server_port`` is set, and "quic-transport" (UDP) is skipped.
        """
        failed = []
        pending = []
        host = self.config["server_host"]
        for scheme, servers in self.servers.items():
            for port, server in servers:
                if not server.is_alive():
                    failed.append((scheme, port))

        if not failed and self.test_server_port:
            for scheme, servers in self.servers.items():
                # TODO(Hexcles): Find a way to test QUIC's UDP port.
                if scheme == "quic-transport":
                    continue
                for port, _server in servers:
                    with socket.socket() as s:
                        s.settimeout(0.1)
                        try:
                            s.connect((host, port))
                        except OSError:
                            pending.append((host, port))

        return failed, pending
288