1import base64
2import hashlib
3import io
4import json
5import os
6import threading
7import traceback
8import socket
9import sys
10from abc import ABCMeta, abstractmethod
11from http.client import HTTPConnection
12from typing import Any, Callable, ClassVar, Optional, Tuple, Type, TYPE_CHECKING
13from urllib.parse import urljoin, urlsplit, urlunsplit
14
15from .actions import actions
16from .protocol import Protocol, BaseProtocolPart
17
18if TYPE_CHECKING:
19    from ..webdriver_server import WebDriverServer
20
# Directory containing this module file.
here = os.path.dirname(__file__)
22
23
def executor_kwargs(test_type, test_environment, run_info_data, **kwargs):
    """Build the keyword arguments used to construct a test executor.

    :param test_type: Test type string, e.g. "reftest" or "wdspec".
    :param test_environment: Object exposing ``config`` and ``cache_manager``.
    :param run_info_data: Not used by this function.
    :param kwargs: Command-line options; ``timeout_multiplier``, ``debug_info``,
                   ``pause_after_test``, ``pause_on_unexpected`` and
                   ``debug_test`` are required.
    :returns: Dict of executor keyword arguments.
    """
    multiplier = kwargs["timeout_multiplier"]
    rv = {
        "server_config": test_environment.config,
        "timeout_multiplier": 1 if multiplier is None else multiplier,
        "debug_info": kwargs["debug_info"],
    }

    if test_type in ("reftest", "print-reftest"):
        rv["screenshot_cache"] = test_environment.cache_manager.dict()

    if test_type == "wdspec":
        for option in ("binary", "webdriver_binary", "webdriver_args"):
            rv[option] = kwargs.get(option)

    # Windows are normally cleaned up after each test so problems are pinned
    # on the test that caused them. Skip that when the user may want to look
    # at the result after the test has run.
    if kwargs["pause_after_test"] or kwargs["pause_on_unexpected"]:
        rv["cleanup_after_test"] = False
    rv["debug_test"] = kwargs["debug_test"]
    return rv
48
49
def strip_server(url):
    """Return ``url`` with its scheme and network location removed.

    Only the path, query and fragment survive, e.g.
    http://example.org:8000/tests?id=1#2 becomes /tests?id=1#2

    :param url: The URL to strip.
    """
    parts = urlsplit(url)._replace(scheme="", netloc="")
    return urlunsplit(parts)
62
63
class TestharnessResultConverter(object):
    """Turn a raw testharness.js result tuple into result objects."""

    # Harness-level status integer -> status name.
    harness_codes = {
        0: "OK",
        1: "ERROR",
        2: "TIMEOUT",
        3: "PRECONDITION_FAILED",
    }

    # Subtest status integer -> status name.
    test_codes = {
        0: "PASS",
        1: "FAIL",
        2: "TIMEOUT",
        3: "NOTRUN",
        4: "PRECONDITION_FAILED",
    }

    def __call__(self, test, result, extra=None):
        """Convert a JSON result into a (TestResult, [SubtestResult]) tuple.

        :param test: Test whose ``url`` must match the result's URL.
        :param result: ``(url, status, message, stack, subtests)`` where each
                       subtest is ``(name, status, message, stack)``.
        :param extra: Extra data attached to the harness result.
        """
        result_url, status, message, stack, subtest_results = result
        assert result_url == test.url, ("Got results from %s, expected %s" %
                                        (result_url, test.url))
        harness_result = test.result_cls(self.harness_codes[status], message,
                                         extra=extra, stack=stack)
        subtests = []
        for name, sub_status, sub_message, sub_stack in subtest_results:
            subtests.append(test.subtest_result_cls(name,
                                                    self.test_codes[sub_status],
                                                    sub_message,
                                                    sub_stack))
        return (harness_result, subtests)


testharness_result_converter = TestharnessResultConverter()
88
89
def hash_screenshots(screenshots):
    """Return the hex SHA-1 digest of each base64-encoded screenshot, in order."""
    digests = []
    for encoded in screenshots:
        raw_bytes = base64.b64decode(encoded)
        digests.append(hashlib.sha1(raw_bytes).hexdigest())
    return digests
94
95
96def _ensure_hash_in_reftest_screenshots(extra):
97    """Make sure reftest_screenshots have hashes.
98
99    Marionette internal reftest runner does not produce hashes.
100    """
101    log_data = extra.get("reftest_screenshots")
102    if not log_data:
103        return
104    for item in log_data:
105        if type(item) != dict:
106            # Skip relation strings.
107            continue
108        if "hash" not in item:
109            item["hash"] = hash_screenshots([item["screenshot"]])[0]
110
111
def get_pages(ranges_value, total_pages):
    """Get a set of page numbers to include in a print reftest.

    :param ranges_value: Parsed page ranges as a list e.g. [[1,2], [4], [6,None]].
                         A one-element range selects a single page; ``None`` as
                         the start means page 1 and as the end means the last
                         page. This argument is not modified.
    :param total_pages: Integer total number of pages in the paginated output.
    :retval: Set containing integer page numbers to include in the comparison e.g.
             for the example ranges value and 10 total pages this would be
             {1,2,4,6,7,8,9,10}"""
    if not ranges_value:
        return set(range(1, total_pages + 1))

    rv = set()

    for range_limits in ranges_value:
        # Resolve the limits into locals instead of writing them back into
        # range_limits. The caller's ranges are often reused (e.g. for each
        # reference), and a later call with a different total_pages must not
        # see a previously substituted upper bound.
        if len(range_limits) == 1:
            start = end = range_limits[0]
        else:
            start, end = range_limits[0], range_limits[1]

        if start is None:
            start = 1
        if end is None:
            end = total_pages

        if start > total_pages:
            continue
        rv.update(range(start, end + 1))
    return rv
138
139
def reftest_result_converter(self, test, result):
    """Convert a reftest result dict into a (TestResult, []) tuple.

    ``self`` is unused. It exists because this function is installed as the
    ``convert_result`` class attribute of executors and so is called as a method.

    :param test: Test providing ``result_cls``.
    :param result: Dict with "status" and "message", and optional "extra" and
                   "stack" keys. Missing screenshot hashes in "extra" are filled in.
    """
    extra = result.get("extra", {})
    _ensure_hash_in_reftest_screenshots(extra)
    harness_result = test.result_cls(result["status"],
                                     result["message"],
                                     extra=extra,
                                     stack=result.get("stack"))
    return harness_result, []
148
149
def pytest_result_converter(self, test, data):
    """Convert wdspec (pytest) output into a (TestResult, [SubtestResult]) tuple.

    ``self`` is unused; this function is installed as ``convert_result`` and
    called as a method.

    :param test: Test providing ``result_cls`` and ``subtest_result_cls``.
    :param data: ``(harness_args, subtest_args_list)``; the list may be None,
                 which is treated as no subtests.
    """
    harness_data, subtest_data = data
    subtest_rows = [] if subtest_data is None else subtest_data
    return (test.result_cls(*harness_data),
            [test.subtest_result_cls(*row) for row in subtest_rows])
160
161
def crashtest_result_converter(self, test, result):
    """Convert a crashtest result dict into a (TestResult, []) tuple.

    ``self`` is unused; this function is installed as ``convert_result``. The
    result dict's items are passed to ``test.result_cls`` as keyword arguments.
    """
    harness_result = test.result_cls(**result)
    return harness_result, []
164
165
class ExecutorException(Exception):
    """Exception that carries an explicit test status.

    TestExecutor.result_from_exception uses ``status`` as the test status when
    it is one of the result class's statuses, and prefixes the report with
    ``message``.

    :param status: Status string to report.
    :param message: Description of the failure.
    """
    def __init__(self, status, message):
        # Pass the values to Exception so ``args`` is populated even for keyword
        # construction. str() and pickling rely on it; without it, keyword-built
        # instances pickle with empty args and cannot be reconstructed.
        super().__init__(status, message)
        self.status = status
        self.message = message
170
171
class TimedRunner(object):
    """Run a test in a worker thread and enforce an external timeout.

    Subclasses implement ``set_timeout`` and ``run_func``, and may override
    ``before_run``. ``run_func`` is expected to store a ``(success, data)``
    tuple in ``self.result`` and set ``self.result_flag`` when it is done.

    :param logger: Logger used for crash diagnostics.
    :param func: Callable stored for use by subclasses.
    :param protocol: Protocol whose ``is_alive()`` tells a hang from a crash.
    :param url: URL of the test being run.
    :param timeout: Test timeout in seconds; a falsy value waits forever.
    :param extra_timeout: Grace period added on top of ``timeout``.
    """
    def __init__(self, logger, func, protocol, url, timeout, extra_timeout):
        self.func = func
        self.logger = logger
        self.result = None
        self.protocol = protocol
        self.url = url
        self.timeout = timeout
        self.extra_timeout = extra_timeout
        self.result_flag = threading.Event()

    def run(self):
        """Run the test and return a ``(success, data)`` tuple.

        On failure ``data`` is ``(status, message)``. The status is
        "INTERNAL-ERROR", "EXTERNAL-TIMEOUT", "CRASH", or whatever error a
        setup step returned.
        """
        for setup_fn in [self.set_timeout, self.before_run]:
            err = setup_fn()
            if err:
                self.result = (False, err)
                return self.result

        executor = threading.Thread(target=self.run_func)
        executor.start()

        # Add twice the extra timeout since the called function is expected to
        # wait at least self.timeout + self.extra_timeout and this gives some leeway
        timeout = self.timeout + 2 * self.extra_timeout if self.timeout else None
        finished = self.result_flag.wait(timeout)
        if self.result is None:
            if finished:
                # flag is True unless we timeout; this *shouldn't* happen, but
                # it can if self.run_func fails to set self.result due to raising
                self.result = False, ("INTERNAL-ERROR", "%s.run_func didn't set a result" %
                                      self.__class__.__name__)
            else:
                if self.protocol.is_alive():
                    message = "Executor hit external timeout (this may indicate a hang)\n"
                    # Get a traceback for the current stack of the executor
                    # thread. The thread may have exited between the wait
                    # timing out and this lookup, in which case it has no frame.
                    frame = sys._current_frames().get(executor.ident)
                    if frame is not None:
                        message += "".join(traceback.format_stack(frame))
                    self.result = False, ("EXTERNAL-TIMEOUT", message)
                else:
                    self.logger.info("Browser not responding, setting status to CRASH")
                    self.result = False, ("CRASH", None)
        elif self.result[1] is None:
            # We didn't get any data back from the test, so check if the
            # browser is still responsive
            if self.protocol.is_alive():
                self.result = False, ("INTERNAL-ERROR", None)
            else:
                self.logger.info("Browser not responding, setting status to CRASH")
                self.result = False, ("CRASH", None)

        return self.result

    def set_timeout(self):
        """Apply the test timeout; return a truthy error to abort the run."""
        raise NotImplementedError

    def before_run(self):
        """Optional hook before the worker starts; return a truthy error to abort."""
        pass

    def run_func(self):
        """Worker-thread body; must set ``self.result`` and ``self.result_flag``."""
        raise NotImplementedError
231
232
class TestExecutor(object):
    """Abstract Base class for object that actually executes the tests in a
    specific browser. Typically there will be a different TestExecutor
    subclass for each test type and method of executing tests.

    :param browser: ExecutorBrowser instance providing properties of the
                    browser that will be tested.
    :param server_config: Dictionary of wptserve server configuration of the
                          form stored in TestEnvironment.config
    :param timeout_multiplier: Multiplier relative to base timeout to use
                               when setting test timeout.
    """
    # NOTE(review): this is the Python 2 metaclass spelling. Python 3 ignores
    # it, so @abstractmethod below is not enforced when instantiating.
    __metaclass__ = ABCMeta

    test_type = None  # type: ClassVar[str]
    # convert_result is a class variable set to a callable converter
    # (e.g. reftest_result_converter) converting from an instance of
    # URLManifestItem (e.g. RefTest) + type-dependent results object +
    # type-dependent extra data, returning a tuple of Result and list of
    # SubtestResult. For now, any callable is accepted. TODO: Make this type
    # stricter when more of the surrounding code is annotated.
    convert_result = None  # type: ClassVar[Callable[..., Any]]
    supports_testdriver = False
    supports_jsshell = False
    # Extra timeout to use after internal test timeout at which the harness
    # should force a timeout
    extra_timeout = 5  # seconds

    def __init__(self, logger, browser, server_config, timeout_multiplier=1,
                 debug_info=None, **kwargs):
        self.logger = logger
        self.runner = None
        self.browser = browser
        self.server_config = server_config
        self.timeout_multiplier = timeout_multiplier
        self.debug_info = debug_info
        # Environment of the previously run test; run_test compares against it
        # to decide whether on_environment_change must be called.
        self.last_environment = {"protocol": "http",
                                 "prefs": {}}
        self.protocol = None  # This must be set in subclasses

    def setup(self, runner):
        """Run steps needed before tests can be started e.g. connecting to
        browser instance

        :param runner: TestRunner instance that is going to run the tests"""
        self.runner = runner
        if self.protocol is not None:
            self.protocol.setup(runner)

    def teardown(self):
        """Run cleanup steps after tests have finished"""
        if self.protocol is not None:
            self.protocol.teardown()

    def reset(self):
        """Re-initialize internal state to facilitate repeated test execution
        as implemented by the `--rerun` command-line argument."""
        pass

    def run_test(self, test):
        """Run a particular test and send the result to the runner.

        Any exception from the environment change or from do_test is logged
        and converted into a result by result_from_exception. The result is
        sent as a "test_ended" message.

        :param test: The test to run"""
        try:
            if test.environment != self.last_environment:
                self.on_environment_change(test.environment)
            result = self.do_test(test)
        except Exception as e:
            exception_string = traceback.format_exc()
            self.logger.warning(exception_string)
            result = self.result_from_exception(test, e, exception_string)

        # log result of parent test
        if result[0].status == "ERROR":
            self.logger.debug(result[0].message)

        self.last_environment = test.environment

        self.runner.send_message("test_ended", test, result)

    def server_url(self, protocol, subdomain=False):
        """Return "scheme://host:port" of the test server for ``protocol``.

        "h2" is served over the https scheme on its own port. When
        ``subdomain`` is true the host is prefixed with "www.".
        """
        scheme = "https" if protocol == "h2" else protocol
        host = self.server_config["browser_host"]
        if subdomain:
            # The only supported subdomain filename flag is "www".
            host = "{subdomain}.{host}".format(subdomain="www", host=host)
        return "{scheme}://{host}:{port}".format(scheme=scheme, host=host,
            port=self.server_config["ports"][protocol][0])

    def test_url(self, test):
        """Return the absolute URL of ``test`` on the test server."""
        return urljoin(self.server_url(test.environment["protocol"],
                                       test.subdomain), test.url)

    @abstractmethod
    def do_test(self, test):
        """Test-type and protocol specific implementation of running a
        specific test.

        :param test: The test to run."""
        pass

    def on_environment_change(self, new_environment):
        """Hook called when a test's environment differs from the previous one."""
        pass

    def result_from_exception(self, test, e, exception_string):
        """Build a ``(Result, [])`` tuple describing an exception from a test.

        :param test: The test that was being run.
        :param e: The exception instance.
        :param exception_string: Formatted traceback for ``e``.
        :returns: ``(test.result_cls(status, message), [])``. The status is
                  ``e.status`` if the result class allows it, else
                  "INTERNAL-ERROR".
        """
        if hasattr(e, "status") and e.status in test.result_cls.statuses:
            status = e.status
        else:
            status = "INTERNAL-ERROR"
        # ``message`` may be present but None (e.g. an ExecutorException built
        # with message=None). Treat that like a missing message instead of
        # rendering the literal text "None".
        raw_message = getattr(e, "message", None)
        message = "" if raw_message is None else str(raw_message)
        if message:
            message += "\n"
        message += exception_string
        return test.result_cls(status, message), []

    def wait(self):
        """Delegate to the protocol's base ``wait()``."""
        self.protocol.base.wait()
351
352
class TestharnessExecutor(TestExecutor):
    """Base class for executors running testharness.js tests."""
    convert_result = testharness_result_converter
355
356
class RefTestExecutor(TestExecutor):
    """Base class for executors running (non-print) reftests.

    :param screenshot_cache: Shared screenshot cache; kept on the instance as
                             ``screenshot_cache``.
    """
    convert_result = reftest_result_converter
    is_print = False

    def __init__(self, logger, browser, server_config, timeout_multiplier=1, screenshot_cache=None,
                 debug_info=None, **kwargs):
        TestExecutor.__init__(self,
                              logger,
                              browser,
                              server_config,
                              debug_info=debug_info,
                              timeout_multiplier=timeout_multiplier)
        self.screenshot_cache = screenshot_cache
368
369
class CrashtestExecutor(TestExecutor):
    """Base class for executors running crashtests."""
    convert_result = crashtest_result_converter
372
373
class PrintRefTestExecutor(TestExecutor):
    """Base class for executors running print (paginated) reftests."""
    is_print = True
    convert_result = reftest_result_converter
377
378
class RefTestImplementation(object):
    """Reftest comparison logic shared by reftest executors.

    Screenshots a test and its references through ``executor.screenshot``,
    caches the results, and compares them either exactly or within fuzzy
    tolerances. Nested references are walked depth-first.
    """
    def __init__(self, executor):
        self.timeout_multiplier = executor.timeout_multiplier
        self.executor = executor
        # Cache of (url, viewport_size, dpi) -> (list of screenshot hashes,
        # screenshots). get_hash stores the screenshots it took. If an entry's
        # screenshots are None, run_test re-takes them via retake_screenshot
        # when it needs them for a failure report.
        self.screenshot_cache = self.executor.screenshot_cache
        # Lines collected during run_test; joined into the failure message.
        self.message = None

    def setup(self):
        pass

    def teardown(self):
        pass

    @property
    def logger(self):
        return self.executor.logger

    def get_hash(self, test, viewport_size, dpi, page_ranges):
        """Return ``(True, (hashes, screenshots))`` for ``test``, using the cache.

        If taking the screenshot fails, returns ``(False, data)`` with the
        executor's error data. On success, also records "<url> <hashes>" in
        ``self.message``.
        """
        # NOTE(review): page_ranges is not part of the key, so a cached entry is
        # reused for the same url/viewport/dpi even when page_ranges differ.
        # Confirm this is intended.
        key = (test.url, viewport_size, dpi)

        if key not in self.screenshot_cache:
            success, data = self.get_screenshot_list(test, viewport_size, dpi, page_ranges)

            if not success:
                return False, data

            screenshots = data
            hash_values = hash_screenshots(data)
            self.screenshot_cache[key] = (hash_values, screenshots)

            rv = (hash_values, screenshots)
        else:
            rv = self.screenshot_cache[key]

        self.message.append("%s %s" % (test.url, rv[0]))
        return True, rv

    def reset(self):
        self.screenshot_cache.clear()

    def check_pass(self, hashes, screenshots, urls, relation, fuzzy):
        """Check if a test passes, and return a tuple of (pass, page_idx),
        where page_idx is the zero-based index of the first page on which a
        difference occurs if any, or None if there are no differences.

        :param hashes: ``(lhs_hashes, rhs_hashes)``, one hash per page.
        :param screenshots: ``(lhs_screenshots, rhs_screenshots)``.
        :param urls: ``(lhs_url, rhs_url)``, used in diagnostics.
        :param relation: "==" or "!=".
        :param fuzzy: ``((min, max) per-channel, (min, max) pixel count)``
                      tolerances, or a falsy value for an exact comparison.
        """

        assert relation in ("==", "!=")
        lhs_hashes, rhs_hashes = hashes
        lhs_screenshots, rhs_screenshots = screenshots

        if len(lhs_hashes) != len(rhs_hashes):
            self.logger.info("Got different number of pages")
            return relation == "!=", None

        assert len(lhs_screenshots) == len(lhs_hashes) == len(rhs_screenshots) == len(rhs_hashes)

        # ``hashes`` is always an (lhs, rhs) pair, so the page count has to
        # come from the per-side lists. Page numbers are only reported for
        # multi-page (print) output.
        multi_page = len(lhs_hashes) > 1

        for (page_idx, (lhs_hash,
                        rhs_hash,
                        lhs_screenshot,
                        rhs_screenshot)) in enumerate(zip(lhs_hashes,
                                                          rhs_hashes,
                                                          lhs_screenshots,
                                                          rhs_screenshots)):
            comparison_screenshots = (lhs_screenshot, rhs_screenshot)
            log_page_idx = page_idx if multi_page else None
            if not fuzzy or fuzzy == ((0, 0), (0, 0)):
                equal = lhs_hash == rhs_hash
                # sometimes images can have different hashes, but pixels can be identical.
                if not equal:
                    self.logger.info("Image hashes didn't match%s, checking pixel differences" %
                                     ("" if log_page_idx is None else " on page %i" % (page_idx + 1)))
                    max_per_channel, pixels_different = self.get_differences(comparison_screenshots,
                                                                             urls,
                                                                             log_page_idx)
                    equal = pixels_different == 0 and max_per_channel == 0
            else:
                max_per_channel, pixels_different = self.get_differences(comparison_screenshots,
                                                                         urls,
                                                                         log_page_idx)
                allowed_per_channel, allowed_different = fuzzy
                self.logger.info("Allowed %s pixels different, maximum difference per channel %s" %
                                 ("-".join(str(item) for item in allowed_different),
                                  "-".join(str(item) for item in allowed_per_channel)))
                equal = ((pixels_different == 0 and allowed_different[0] == 0) or
                         (max_per_channel == 0 and allowed_per_channel[0] == 0) or
                         (allowed_per_channel[0] <= max_per_channel <= allowed_per_channel[1] and
                          allowed_different[0] <= pixels_different <= allowed_different[1]))
            if not equal:
                return (False if relation == "==" else True, page_idx)
        # All screenshots were equal within the fuzziness
        return (True if relation == "==" else False, None)

    def get_differences(self, screenshots, urls, page_idx=None):
        """Return ``(max per-channel difference, differing pixel count)``.

        :param screenshots: Pair of base64-encoded images to compare.
        :param urls: Pair of URLs, used in solid-colour diagnostics.
        :param page_idx: Zero-based page index for logging, or None.
        """
        from PIL import Image, ImageChops, ImageStat

        lhs = Image.open(io.BytesIO(base64.b64decode(screenshots[0]))).convert("RGB")
        rhs = Image.open(io.BytesIO(base64.b64decode(screenshots[1]))).convert("RGB")
        self.check_if_solid_color(lhs, urls[0])
        self.check_if_solid_color(rhs, urls[1])
        diff = ImageChops.difference(lhs, rhs)
        minimal_diff = diff.crop(diff.getbbox())
        mask = minimal_diff.convert("L", dither=None)
        stat = ImageStat.Stat(minimal_diff, mask)
        per_channel = max(item[1] for item in stat.extrema)
        count = stat.count[0]
        self.logger.info("Found %s pixels different, maximum difference per channel %s%s" %
                         (count,
                          per_channel,
                          "" if page_idx is None else " on page %i" % (page_idx + 1)))
        return per_channel, count

    def check_if_solid_color(self, image, url):
        """Record in ``self.message`` when ``image`` is a single solid colour.

        ``image.getextrema()`` is iterated as one (min, max) pair per band (the
        images from get_differences are RGB). The image is solid when every
        band's min equals its max.
        """
        extrema = image.getextrema()
        if all(low == high for low, high in extrema):
            color = ''.join('%02X' % value for value, _ in extrema)
            self.message.append("Screenshot is solid color 0x%s for %s\n" % (color, url))

    def run_test(self, test):
        """Compare ``test`` against its references.

        :returns: Dict with "status" and "message". On failure it also has
                  "extra" with "reftest_screenshots" describing the failing pair.
        """
        viewport_size = test.viewport_size
        dpi = test.dpi
        page_ranges = test.page_ranges
        self.message = []

        # Depth-first search of reference tree, with the goal
        # of reaching a leaf node with only pass results

        stack = list(((test, item[0]), item[1]) for item in reversed(test.references))
        page_idx = None
        while stack:
            hashes = [None, None]
            screenshots = [None, None]
            urls = [None, None]

            nodes, relation = stack.pop()
            fuzzy = self.get_fuzzy(test, nodes, relation)

            for i, node in enumerate(nodes):
                success, data = self.get_hash(node, viewport_size, dpi, page_ranges)
                if success is False:
                    return {"status": data[0], "message": data[1]}

                hashes[i], screenshots[i] = data
                urls[i] = node.url

            is_pass, page_idx = self.check_pass(hashes, screenshots, urls, relation, fuzzy)
            if is_pass:
                if nodes[1].references:
                    stack.extend(list(((nodes[1], item[0]), item[1])
                                      for item in reversed(nodes[1].references)))
                else:
                    # We passed
                    return {"status": "PASS", "message": None}

        # We failed, so construct a failure message

        if page_idx is None:
            # default to outputting the last page
            page_idx = -1
        for i, (node, screenshot) in enumerate(zip(nodes, screenshots)):
            if screenshot is None:
                success, screenshot = self.retake_screenshot(node, viewport_size, dpi, page_ranges)
                if success:
                    screenshots[i] = screenshot

        log_data = [
            {"url": nodes[0].url,
             "screenshot": screenshots[0][page_idx],
             "hash": hashes[0][page_idx]},
            relation,
            {"url": nodes[1].url,
             "screenshot": screenshots[1][page_idx],
             "hash": hashes[1][page_idx]},
        ]

        return {"status": "FAIL",
                "message": "\n".join(self.message),
                "extra": {"reftest_screenshots": log_data}}

    def get_fuzzy(self, root_test, test_nodes, relation):
        """Return the fuzzy tolerance for comparing ``test_nodes``.

        ``root_test.fuzzy_override`` is searched first, then
        ``test_nodes[0].fuzzy``. In each source the keys tried, in order, are
        ``(lhs_url, rhs_url, relation)``, the rhs URL, then None.
        """
        full_key = tuple([item.url for item in test_nodes] + [relation])
        ref_only_key = test_nodes[1].url

        fuzzy_override = root_test.fuzzy_override
        fuzzy = test_nodes[0].fuzzy

        sources = [fuzzy_override, fuzzy]
        keys = [full_key, ref_only_key, None]
        value = None
        for source in sources:
            for key in keys:
                if key in source:
                    value = source[key]
                    break
            if value:
                break
        return value

    def retake_screenshot(self, node, viewport_size, dpi, page_ranges):
        """Take ``node``'s screenshots again and store them in the cache entry.

        The cached hashes are kept. Returns ``(success, data)``.
        """
        success, data = self.get_screenshot_list(node,
                                                 viewport_size,
                                                 dpi,
                                                 page_ranges)
        if not success:
            return False, data

        key = (node.url, viewport_size, dpi)
        hash_val, _ = self.screenshot_cache[key]
        self.screenshot_cache[key] = hash_val, data
        return True, data

    def get_screenshot_list(self, node, viewport_size, dpi, page_ranges):
        """Call ``executor.screenshot``, wrapping a single screenshot in a list."""
        success, data = self.executor.screenshot(node, viewport_size, dpi, page_ranges)
        if success and not isinstance(data, list):
            return success, [data]
        return success, data
597
598
class WdspecExecutor(TestExecutor):
    """Executor for wdspec tests, which are pytest files driven through
    ``pytestrunner`` against the WebDriver server started by ``protocol_cls``.
    """
    convert_result = pytest_result_converter
    protocol_cls = None  # type: ClassVar[Type[Protocol]]

    def __init__(self, logger, browser, server_config, webdriver_binary,
                 webdriver_args, timeout_multiplier=1, capabilities=None,
                 debug_info=None, environ=None, **kwargs):
        self.do_delayed_imports()
        # The base initializer also stores timeout_multiplier on the instance.
        TestExecutor.__init__(self, logger, browser, server_config,
                              timeout_multiplier=timeout_multiplier,
                              debug_info=debug_info)
        self.webdriver_binary = webdriver_binary
        self.webdriver_args = webdriver_args
        self.capabilities = capabilities
        self.environ = {} if environ is None else environ
        self.output_handler_kwargs = None
        self.output_handler_start_kwargs = None

    def setup(self, runner):
        """Create the protocol, then run the base setup (which sets it up)."""
        self.protocol = self.protocol_cls(self, self.browser)
        super().setup(runner)

    def is_alive(self):
        return self.protocol.is_alive()

    def on_environment_change(self, new_environment):
        pass

    def do_test(self, test):
        """Run one wdspec file under a timeout and convert its result."""
        timeout = test.timeout * self.timeout_multiplier + self.extra_timeout

        wdspec_run = WdspecRun(self.do_wdspec,
                               self.protocol.session_config,
                               test.abs_path,
                               timeout)
        success, data = wdspec_run.run()

        if not success:
            return (test.result_cls(*data), [])
        return self.convert_result(test, data)

    def do_wdspec(self, session_config, path, timeout):
        return pytestrunner.run(path,
                                self.server_config,
                                session_config,
                                timeout=timeout)

    def do_delayed_imports(self):
        # Import lazily and publish as a module global for do_wdspec.
        global pytestrunner
        from . import pytestrunner
650
651
class WdspecRun(object):
    """Run a wdspec test function in a worker thread with a timeout.

    :param func: Called as ``func(session, path, timeout)``.
    :param session: Session configuration passed to ``func``.
    :param path: Path of the test file passed to ``func``.
    :param timeout: Seconds to wait for ``func`` to finish.
    """
    def __init__(self, func, session, path, timeout):
        self.func = func
        self.result = (None, None)
        self.session = session
        self.path = path
        self.timeout = timeout
        self.result_flag = threading.Event()

    def run(self):
        """Runs function in a thread and interrupts it if it exceeds the
        given timeout.  Returns (True, (Result, [SubtestResult ...])) in
        case of success, or (False, (status, extra information)) in the
        event of failure.
        """

        executor = threading.Thread(target=self._run)
        executor.start()

        self.result_flag.wait(self.timeout)
        if self.result[1] is None:
            self.result = False, ("EXTERNAL-TIMEOUT", None)

        return self.result

    def _run(self):
        """Worker body: call ``func`` and record ``(success, data)``."""
        try:
            self.result = True, self.func(self.session, self.path, self.timeout)
        except (socket.timeout, IOError):
            self.result = False, ("CRASH", None)
        except Exception as e:
            # Most Python 3 exceptions have no ``message`` attribute, and it may
            # also be None. A getattr without a default would raise here, leave
            # no result, and make run() misreport the failure as a timeout.
            raw_message = getattr(e, "message", None)
            message = "" if raw_message is None else str(raw_message)
            if message:
                message += "\n"
            message += traceback.format_exc()
            self.result = False, ("INTERNAL-ERROR", message)
        finally:
            self.result_flag.set()
690
691
class ConnectionlessBaseProtocolPart(BaseProtocolPart):
    """Base protocol part for protocols with no browser connection.

    Every operation is a no-op returning None, except ``window_handles``,
    which reports an empty list.
    """
    def load(self, url):
        return None

    def execute_script(self, script, asynchronous=False):
        return None

    def set_timeout(self, timeout):
        return None

    def wait(self):
        return None

    def set_window(self, handle):
        return None

    def window_handles(self):
        return list()
710
711
class ConnectionlessProtocol(Protocol):
    """Protocol that never connects to anything."""
    implements = [ConnectionlessBaseProtocolPart]

    def connect(self):
        """No connection is made."""

    def after_connect(self):
        """Nothing to do after the (absent) connection."""
720
721
class WdspecProtocol(Protocol):
    """Protocol that starts a WebDriver HTTP server for wdspec tests.

    It launches ``server_cls`` and exposes ``session_config`` (host, port and
    capabilities) for the tests to connect with.
    """
    server_cls = None  # type: ClassVar[Optional[Type[WebDriverServer]]]

    implements = [ConnectionlessBaseProtocolPart]

    def __init__(self, executor, browser):
        Protocol.__init__(self, executor, browser)
        self.webdriver_binary = executor.webdriver_binary
        self.webdriver_args = executor.webdriver_args
        self.capabilities = self.executor.capabilities
        self.session_config = None
        self.server = None
        # Server environment: this process's environment plus executor overrides.
        self.environ = os.environ.copy()
        self.environ.update(executor.environ)
        self.output_handler_kwargs = executor.output_handler_kwargs
        self.output_handler_start_kwargs = executor.output_handler_start_kwargs

    def connect(self):
        """Connect to browser via the HTTP server."""
        self.server = self.server_cls(
            self.logger,
            binary=self.webdriver_binary,
            args=self.webdriver_args,
            env=self.environ)
        self.server.start(block=False,
                          output_handler_kwargs=self.output_handler_kwargs,
                          output_handler_start_kwargs=self.output_handler_start_kwargs)
        self.logger.info(
            "WebDriver HTTP server listening at %s" % self.server.url)
        self.session_config = {"host": self.server.host,
                               "port": self.server.port,
                               "capabilities": self.capabilities}

    def after_connect(self):
        pass

    def teardown(self):
        if self.server is not None and self.server.is_alive():
            self.server.stop()

    def is_alive(self):
        """Test that the connection is still alive.

        Because the remote communication happens over HTTP we need to
        make an explicit request to the remote.  It is allowed for
        WebDriver spec tests to not have a WebDriver session, since this
        may be what is tested.

        An HTTP request to an invalid path that results in a 404 is
        proof enough to us that the server is alive and kicking.

        Returns False if no server was started or the connection fails at
        the socket level (e.g. refused or reset).
        """
        if self.server is None:
            return False
        conn = HTTPConnection(self.server.host, self.server.port)
        try:
            conn.request("HEAD", self.server.base_path + "invalid")
            res = conn.getresponse()
            return res.status == 404
        except OSError:
            # Connection refused/reset/timed out: the server is not responding.
            return False
        finally:
            # Always release the socket; previously it was never closed.
            conn.close()
777
778
class CallbackHandler(object):
    """Handle callbacks from testdriver-using tests.

    The default implementation here makes sense for things that are roughly like
    WebDriver. Things that are more different to WebDriver may need to create a
    fully custom implementation."""

    # Exception types meaning "action not supported". They are reported to the
    # test as an error rather than re-raised.
    unimplemented_exc = (NotImplementedError,)  # type: ClassVar[Tuple[Type[Exception], ...]]

    def __init__(self, logger, protocol, test_window):
        self.protocol = protocol
        self.test_window = test_window
        self.logger = logger
        # Dispatch table from callback command name to handler.
        self.callbacks = {
            "action": self.process_action,
            "complete": self.process_complete
        }

        # One handler instance per action class, keyed by its ``name``.
        self.actions = {cls.name: cls(self.logger, self.protocol) for cls in actions}

    def __call__(self, result):
        """Dispatch a ``(url, command, payload)`` callback to its handler.

        process_complete returns ``(True, data)``; process_action returns
        ``(False, None)``.

        :raises ValueError: If no handler is registered for ``command``.
        """
        url, command, payload = result
        self.logger.debug("Got async callback: %s" % command)
        try:
            handler = self.callbacks[command]
        except KeyError:
            raise ValueError("Unknown callback type %r" % command)
        return handler(url, payload)

    def process_complete(self, url, payload):
        """Return ``(True, [url path] + payload)``."""
        return True, [strip_server(url)] + payload

    def process_action(self, url, payload):
        """Run a testdriver action and report its outcome back to the test.

        Unimplemented actions are reported as an error. Any other exception is
        logged, reported as an error, then re-raised.

        :raises ValueError: If the action name is unknown.
        """
        action = payload["action"]
        cmd_id = payload["id"]
        self.logger.debug("Got action: %s" % action)
        try:
            handler = self.actions[action]
        except KeyError:
            raise ValueError("Unknown action %s" % action)

        try:
            with ActionContext(self.logger, self.protocol, payload.get("context")):
                result = handler(payload)
        except self.unimplemented_exc:
            not_implemented = "Action %s not implemented" % action
            self.logger.warning(not_implemented)
            self._send_message(cmd_id, "complete", "error", not_implemented)
        except Exception:
            self.logger.warning("Action %s failed" % action)
            self.logger.warning(traceback.format_exc())
            self._send_message(cmd_id, "complete", "error")
            raise
        else:
            self.logger.debug("Action %s completed with result %s" % (action, result))
            encoded = json.dumps({"result": result})
            self._send_message(cmd_id, "complete", "success", encoded)

        return False, None

    def _send_message(self, cmd_id, message_type, status, message=None):
        """Forward a message to the test via ``protocol.testdriver``."""
        self.protocol.testdriver.send_message(cmd_id, message_type, status, message=message)
840
841
class ActionContext(object):
    """Context manager that runs a testdriver action in a given window.

    With a ``context`` of None it does nothing. Otherwise entering remembers
    the current window and switches to ``context``; exiting switches back.
    Exceptions are not suppressed.
    """
    def __init__(self, logger, protocol, context):
        self.logger = logger
        self.protocol = protocol
        self.context = context
        self.initial_window = None

    def __enter__(self):
        if self.context is not None:
            self.initial_window = self.protocol.base.current_window
            self.logger.debug("Switching to window %s" % self.context)
            self.protocol.testdriver.switch_to_window(self.context, self.initial_window)

    def __exit__(self, *args):
        if self.context is not None:
            self.logger.debug("Switching back to initial window")
            self.protocol.base.set_window(self.initial_window)
            self.initial_window = None
864