1import base64
2import hashlib
3import io
4import json
5import os
6import threading
7import traceback
8import socket
9import sys
10from abc import ABCMeta, abstractmethod
11from typing import Any, Callable, ClassVar, Tuple, Type
12from urllib.parse import urljoin, urlsplit, urlunsplit
13
14from . import pytestrunner
15from .actions import actions
16from .protocol import Protocol, WdspecProtocol
17
18
# Directory containing this module file.
here = os.path.dirname(__file__)
20
21
def executor_kwargs(test_type, test_environment, run_info_data, **kwargs):
    """Build the keyword arguments passed to executors of ``test_type``.

    :param test_type: test type name, e.g. "testharness", "reftest",
                      "print-reftest" or "wdspec".
    :param test_environment: object exposing ``config`` and, for reftests,
                             ``cache_manager.dict()``.
    :param run_info_data: unused here; kept for interface compatibility.
    :param kwargs: command-line style options. ``timeout_multiplier``,
                   ``debug_info``, ``pause_after_test``,
                   ``pause_on_unexpected`` and ``debug_test`` are required.
    :returns: dict of executor keyword arguments.
    """
    multiplier = kwargs["timeout_multiplier"]
    rv = {
        "server_config": test_environment.config,
        # An unset multiplier means "no scaling".
        "timeout_multiplier": 1 if multiplier is None else multiplier,
        "debug_info": kwargs["debug_info"],
    }

    if test_type in ("reftest", "print-reftest"):
        rv["screenshot_cache"] = test_environment.cache_manager.dict()

    if test_type == "wdspec":
        for option in ("binary", "webdriver_binary", "webdriver_args"):
            rv[option] = kwargs.get(option)

    # Executors may close windows after each test so that problems are
    # attributed to the test that caused them; skip that when the user wants
    # to look at the results.
    if kwargs["pause_after_test"] or kwargs["pause_on_unexpected"]:
        rv["cleanup_after_test"] = False
    rv["debug_test"] = kwargs["debug_test"]
    return rv
46
47
def strip_server(url):
    """Return ``url`` with its scheme and network location removed.

    Only the path, query and fragment are kept, e.g.
    http://example.org:8000/tests?id=1#2 becomes /tests?id=1#2

    url - the url to strip"""
    parts = urlsplit(url)
    return urlunsplit(parts._replace(scheme="", netloc=""))
60
61
class TestharnessResultConverter(object):
    """Callable mapping a raw testharness result onto wptrunner result objects."""

    # Numeric harness status -> status name.
    harness_codes = {0: "OK",
                     1: "ERROR",
                     2: "TIMEOUT",
                     3: "PRECONDITION_FAILED"}

    # Numeric subtest status -> status name.
    test_codes = {0: "PASS",
                  1: "FAIL",
                  2: "TIMEOUT",
                  3: "NOTRUN",
                  4: "PRECONDITION_FAILED"}

    def __call__(self, test, result, extra=None):
        """Convert a JSON result into a (TestResult, [SubtestResult]) tuple

        :param test: the test that ran; its ``url`` must match the reported url.
        :param result: ``(url, status, message, stack, subtests)`` where each
                       subtest is ``(name, status, message, stack)`` and the
                       statuses are numeric codes.
        :param extra: forwarded to ``test.result_cls``.
        """
        result_url, status, message, stack, subtest_results = result
        assert result_url == test.url, ("Got results from %s, expected %s" %
                                        (result_url, test.url))
        harness_result = test.result_cls(self.harness_codes[status], message, extra=extra, stack=stack)
        converted = []
        for name, code, sub_message, sub_stack in subtest_results:
            converted.append(test.subtest_result_cls(name, self.test_codes[code], sub_message, sub_stack))
        return harness_result, converted
83
84
# Shared instance (the converter class keeps no per-instance state); used as
# TestharnessExecutor.convert_result.
testharness_result_converter = TestharnessResultConverter()
86
87
def hash_screenshots(screenshots):
    """Return the hex SHA-1 digest of each base64-encoded screenshot, in order."""
    digests = []
    for encoded in screenshots:
        raw_bytes = base64.b64decode(encoded)
        digests.append(hashlib.sha1(raw_bytes).hexdigest())
    return digests
92
93
94def _ensure_hash_in_reftest_screenshots(extra):
95    """Make sure reftest_screenshots have hashes.
96
97    Marionette internal reftest runner does not produce hashes.
98    """
99    log_data = extra.get("reftest_screenshots")
100    if not log_data:
101        return
102    for item in log_data:
103        if type(item) != dict:
104            # Skip relation strings.
105            continue
106        if "hash" not in item:
107            item["hash"] = hash_screenshots([item["screenshot"]])[0]
108
109
def get_pages(ranges_value, total_pages):
    """Get a set of page numbers to include in a print reftest.

    :param ranges_value: Parsed page ranges as a list e.g. [[1,2], [4], [6,None]].
                         Each range is ``[start, end]`` or ``[page]``; a ``None``
                         start means the first page and a ``None`` end means the
                         last page. Ranges may be lists or tuples and are never
                         modified.
    :param total_pages: Integer total number of pages in the paginated output.
    :retval: Set containing integer page numbers to include in the comparison e.g.
             for the example ranges value and 10 total pages this would be
             {1,2,4,6,7,8,9,10}. An empty or None ``ranges_value`` selects every
             page; page numbers greater than ``total_pages`` are never included."""
    if not ranges_value:
        return set(range(1, total_pages + 1))

    rv = set()

    for range_limits in ranges_value:
        # Read the bounds into locals instead of rewriting the caller's list:
        # filling in None in place would leak this call's total_pages into
        # later calls that reuse the same ranges.
        if len(range_limits) == 1:
            start = end = range_limits[0]
        else:
            start, end = range_limits[0], range_limits[1]

        if start is None:
            start = 1
        if end is None:
            end = total_pages

        if start > total_pages:
            continue
        # Clamp to the pages that actually exist.
        rv.update(range(start, min(end, total_pages) + 1))
    return rv
136
137
def reftest_result_converter(self, test, result):
    """Convert a reftest result dict into a ``(Result, [])`` tuple.

    ``result`` must contain ``"status"`` and ``"message"``; ``"extra"`` and
    ``"stack"`` are optional. Missing screenshot hashes in ``extra`` are
    filled in before the result object is built.
    """
    extra = result.get("extra", {})
    _ensure_hash_in_reftest_screenshots(extra)
    status, message = result["status"], result["message"]
    stack = result.get("stack")
    return test.result_cls(status, message, extra=extra, stack=stack), []
146
147
def pytest_result_converter(self, test, data):
    """Convert ``(harness_data, subtest_data)`` from a pytest run into results.

    ``harness_data`` is splatted into ``test.result_cls``; every item of
    ``subtest_data`` (which may be None) is splatted into
    ``test.subtest_result_cls``.

    :returns: ``(Result, [SubtestResult, ...])``
    """
    harness_data, subtest_data = data
    harness_result = test.result_cls(*harness_data)
    subtest_results = []
    if subtest_data is not None:
        subtest_results = [test.subtest_result_cls(*item) for item in subtest_data]
    return harness_result, subtest_results
158
159
def crashtest_result_converter(self, test, result):
    """Build the crashtest result from the keyword dict ``result``; no subtests."""
    harness_result = test.result_cls(**result)
    return harness_result, []
162
163
class ExecutorException(Exception):
    """Exception carrying the test status to report and a message.

    TestExecutor.result_from_exception reads ``status`` and ``message`` from
    raised exceptions.
    """

    def __init__(self, status, message):
        # Initialise the base class so ``args`` is populated: without it
        # ``str(e)`` is empty in tracebacks and the exception cannot be
        # re-created from ``args`` (e.g. when pickled).
        super().__init__(status, message)
        self.status = status
        self.message = message

    def __str__(self):
        return "%s: %s" % (self.status, self.message)
168
169
class TimedRunner(object):
    """Run a test function on a worker thread under an outer timeout.

    Subclasses implement :meth:`set_timeout` and :meth:`run_func` and may
    override :meth:`before_run`. ``run_func`` runs on a separate thread. It
    should store a ``(success, data)`` tuple in ``self.result`` and then set
    ``self.result_flag``, which :meth:`run` waits on.

    :param logger: logger used to report an unresponsive browser.
    :param func: stored as ``self.func`` (presumably called by subclasses'
                 ``run_func``).
    :param protocol: protocol object whose ``is_alive()`` is used to tell a
                     hang apart from a browser crash.
    :param url: URL of the test (stored as ``self.url``).
    :param timeout: internal test timeout in seconds; a falsy value means
                    :meth:`run` waits without a time limit.
    :param extra_timeout: additional grace period in seconds.
    """

    def __init__(self, logger, func, protocol, url, timeout, extra_timeout):
        self.func = func
        self.logger = logger
        self.result = None
        self.protocol = protocol
        self.url = url
        self.timeout = timeout
        self.extra_timeout = extra_timeout
        self.result_flag = threading.Event()

    def run(self):
        """Run the setup hooks, then ``run_func`` on a thread with a timeout.

        :returns: ``(success, data)``. If a setup hook returns a truthy error,
                  returns ``(False, error)`` without starting the thread.
                  Failures detected here are reported as
                  ``(False, (status, message))``.
        """
        for setup_fn in [self.set_timeout, self.before_run]:
            err = setup_fn()
            if err:
                self.result = (False, err)
                return self.result

        executor = threading.Thread(target=self.run_func)
        executor.start()

        # Add twice the extra timeout since the called function is expected to
        # wait at least self.timeout + self.extra_timeout and this gives some leeway
        timeout = self.timeout + 2 * self.extra_timeout if self.timeout else None
        finished = self.result_flag.wait(timeout)
        if self.result is None:
            if finished:
                # flag is True unless we timeout; this *shouldn't* happen, but
                # it can if self.run_func fails to set self.result due to raising
                self.result = False, ("INTERNAL-ERROR", "%s.run_func didn't set a result" %
                                      self.__class__.__name__)
            else:
                if self.protocol.is_alive():
                    message = "Executor hit external timeout (this may indicate a hang)\n"
                    # Get a traceback for the current stack of the executor
                    # thread. The thread may have exited since wait() returned,
                    # in which case it has no frame and indexing would raise
                    # KeyError.
                    frame = sys._current_frames().get(executor.ident)
                    if frame is not None:
                        message += "".join(traceback.format_stack(frame))
                    self.result = False, ("EXTERNAL-TIMEOUT", message)
                else:
                    self.logger.info("Browser not responding, setting status to CRASH")
                    self.result = False, ("CRASH", None)
        elif self.result[1] is None:
            # We didn't get any data back from the test, so check if the
            # browser is still responsive
            if self.protocol.is_alive():
                self.result = False, ("INTERNAL-ERROR", None)
            else:
                self.logger.info("Browser not responding, setting status to CRASH")
                self.result = False, ("CRASH", None)

        return self.result

    def set_timeout(self):
        """Configure timeouts before the run; return a truthy error to abort."""
        raise NotImplementedError

    def before_run(self):
        """Optional pre-run hook; return a truthy error to abort the run."""
        pass

    def run_func(self):
        """Worker-thread body; see the class docstring for its contract."""
        raise NotImplementedError
229
230
class TestExecutor(object):
    """Abstract base for objects that run tests in one particular browser.

    There is normally one subclass per combination of test type and
    mechanism used to drive the browser.

    :param browser: ExecutorBrowser instance describing the browser under
                    test.
    :param server_config: wptserve configuration dictionary, in the form
                          stored in TestEnvironment.config
    :param timeout_multiplier: Factor applied to the base timeout when
                               computing a test's timeout.
    """
    # NOTE(review): ``__metaclass__`` is Python 2 syntax; under Python 3 it is
    # an ordinary class attribute, so ``@abstractmethod`` below is not enforced.
    __metaclass__ = ABCMeta

    test_type = None  # type: ClassVar[str]
    # Converter set by each subclass (e.g. reftest_result_converter). It maps a
    # URLManifestItem instance (e.g. RefTest), a type-dependent results object
    # and type-dependent extra data onto a tuple of Result and list of
    # SubtestResult. Any callable is accepted for now. TODO: make this type
    # stricter when more of the surrounding code is annotated.
    convert_result = None  # type: ClassVar[Callable[..., Any]]
    supports_testdriver = False
    supports_jsshell = False
    # Extra time, after the internal test timeout, before the harness forces a
    # timeout itself.
    extra_timeout = 5  # seconds

    def __init__(self, logger, browser, server_config, timeout_multiplier=1,
                 debug_info=None, **kwargs):
        self.logger = logger
        self.browser = browser
        self.server_config = server_config
        self.timeout_multiplier = timeout_multiplier
        self.debug_info = debug_info
        self.runner = None
        # Subclasses are expected to assign a protocol object.
        self.protocol = None
        self.last_environment = {"protocol": "http", "prefs": {}}

    def setup(self, runner):
        """Prepare for running tests, e.g. by connecting to the browser
        instance.

        :param runner: TestRunner instance that will run the tests"""
        self.runner = runner
        if self.protocol is None:
            return
        self.protocol.setup(runner)

    def teardown(self):
        """Clean up once all tests have finished"""
        if self.protocol is None:
            return
        self.protocol.teardown()

    def reset(self):
        """Re-initialize internal state so tests can be executed again, as
        done for the `--rerun` command-line argument."""
        pass

    def run_test(self, test):
        """Run one test and report the outcome to the runner as "test_ended".

        Exceptions from the environment switch or from ``do_test`` are logged
        and converted into a result via ``result_from_exception``.

        :param test: The test to run"""
        try:
            environment_changed = test.environment != self.last_environment
            if environment_changed:
                self.on_environment_change(test.environment)
            result = self.do_test(test)
        except Exception as e:
            formatted = traceback.format_exc()
            self.logger.warning(formatted)
            result = self.result_from_exception(test, e, formatted)

        # Log the message of an ERROR parent-test result.
        harness_result = result[0]
        if harness_result.status == "ERROR":
            self.logger.debug(harness_result.message)

        self.last_environment = test.environment

        self.runner.send_message("test_ended", test, result)

    def server_url(self, protocol, subdomain=False):
        """Return ``scheme://host:port`` for the wptserve server of ``protocol``.

        "h2" is served over https. With ``subdomain`` the host gets a "www."
        prefix."""
        if protocol == "h2":
            scheme = "https"
        else:
            scheme = protocol
        host = self.server_config["browser_host"]
        if subdomain:
            # "www" is the only supported subdomain filename flag.
            host = "www.{}".format(host)
        port = self.server_config["ports"][protocol][0]
        return "{}://{}:{}".format(scheme, host, port)

    def test_url(self, test):
        """Return the absolute URL of ``test`` for its environment's protocol."""
        base = self.server_url(test.environment["protocol"], test.subdomain)
        return urljoin(base, test.url)

    @abstractmethod
    def do_test(self, test):
        """Test-type and protocol specific implementation of running a
        specific test.

        :param test: The test to run."""
        pass

    def on_environment_change(self, new_environment):
        """Hook called before a test whose environment differs from the last."""
        pass

    def result_from_exception(self, test, e, exception_string):
        """Build a ``(Result, [])`` tuple describing exception ``e``.

        The exception's ``status`` attribute is used when it is one of
        ``test.result_cls.statuses``; otherwise the status is INTERNAL-ERROR.
        The message is the exception's ``message`` attribute (if any) followed
        by ``exception_string``."""
        known_status = hasattr(e, "status") and e.status in test.result_cls.statuses
        status = e.status if known_status else "INTERNAL-ERROR"
        message = str(getattr(e, "message", ""))
        prefix = message + "\n" if message else ""
        return test.result_cls(status, prefix + exception_string), []

    def wait(self):
        """Delegate to ``self.protocol.base.wait()``."""
        return self.protocol.base.wait()
349
350
class TestharnessExecutor(TestExecutor):
    """Base class for testharness test executors; results are converted with
    testharness_result_converter."""
    convert_result = testharness_result_converter
353
354
class RefTestExecutor(TestExecutor):
    """Base class for (non-print) reftest executors.

    Stores the ``screenshot_cache`` mapping passed in; RefTestImplementation
    reads it from the executor.
    """
    is_print = False
    convert_result = reftest_result_converter

    def __init__(self, logger, browser, server_config, timeout_multiplier=1, screenshot_cache=None,
                 debug_info=None, **kwargs):
        # Explicit base-class call (not super()) keeps the original MRO
        # behaviour for subclasses.
        TestExecutor.__init__(self, logger, browser, server_config,
                              timeout_multiplier=timeout_multiplier, debug_info=debug_info)
        self.screenshot_cache = screenshot_cache
366
367
class CrashtestExecutor(TestExecutor):
    """Base class for crashtest executors; results are converted with
    crashtest_result_converter."""
    convert_result = crashtest_result_converter
370
371
class PrintRefTestExecutor(TestExecutor):
    """Base class for print reftest executors (``is_print`` is True).

    Results use the same converter as screen reftests."""
    is_print = True
    convert_result = reftest_result_converter
375
376
class RefTestImplementation(object):
    """Executor-independent reftest logic.

    Takes screenshots through the wrapped executor (with caching), compares
    test and reference screenshots (optionally with fuzzy tolerances) and walks
    chains of references depth-first.

    The executor must provide ``timeout_multiplier``, ``screenshot_cache``,
    ``logger`` and ``screenshot(node, viewport_size, dpi, page_ranges)``
    returning ``(success, data)``.
    """

    def __init__(self, executor):
        self.timeout_multiplier = executor.timeout_multiplier
        self.executor = executor
        # Cache keyed by (url, viewport_size, dpi) whose values are
        # (list of page hashes, list of page screenshots). get_hash fills in
        # entries, and retake_screenshot replaces the screenshot list of an
        # existing entry. run_test re-takes a screenshot when the cached list is
        # None. The mapping object belongs to the executor.
        self.screenshot_cache = self.executor.screenshot_cache
        # Log lines collected during run_test; None until the first run.
        self.message = None

    def setup(self):
        pass

    def teardown(self):
        pass

    @property
    def logger(self):
        return self.executor.logger

    def get_hash(self, test, viewport_size, dpi, page_ranges):
        """Return ``(True, (hashes, screenshots))`` for ``test``, using the cache.

        If taking the screenshot fails, returns ``(False, data)`` with the
        executor's error data. On success appends "<url> <hashes>" to
        ``self.message``.

        NOTE(review): ``page_ranges`` is not part of the cache key.
        """
        key = (test.url, viewport_size, dpi)

        if key not in self.screenshot_cache:
            success, data = self.get_screenshot_list(test, viewport_size, dpi, page_ranges)

            if not success:
                return False, data

            screenshots = data
            hash_values = hash_screenshots(data)
            self.screenshot_cache[key] = (hash_values, screenshots)

            rv = (hash_values, screenshots)
        else:
            rv = self.screenshot_cache[key]

        self.message.append("%s %s" % (test.url, rv[0]))
        return True, rv

    def reset(self):
        """Empty the screenshot cache."""
        self.screenshot_cache.clear()

    def check_pass(self, hashes, screenshots, urls, relation, fuzzy):
        """Check if a test passes, and return a tuple of (pass, page_idx),
        where page_idx is the zero-based index of the first page on which a
        difference occurs if any, or None if there are no differences.

        :param hashes: (lhs page hashes, rhs page hashes)
        :param screenshots: (lhs page screenshots, rhs page screenshots)
        :param urls: (lhs url, rhs url), used in log messages
        :param relation: "==" or "!="
        :param fuzzy: None, or ((min, max) per-channel difference,
                      (min, max) number of differing pixels)"""

        assert relation in ("==", "!=")
        lhs_hashes, rhs_hashes = hashes
        lhs_screenshots, rhs_screenshots = screenshots

        if len(lhs_hashes) != len(rhs_hashes):
            self.logger.info("Got different number of pages")
            return relation == "!=", None

        assert len(lhs_screenshots) == len(lhs_hashes) == len(rhs_screenshots) == len(rhs_hashes)

        # Page numbers are only worth logging for multi-page output. ``hashes``
        # is always an (lhs, rhs) pair, so the page count comes from one side.
        multi_page = len(lhs_hashes) > 1

        for (page_idx, (lhs_hash,
                        rhs_hash,
                        lhs_screenshot,
                        rhs_screenshot)) in enumerate(zip(lhs_hashes,
                                                          rhs_hashes,
                                                          lhs_screenshots,
                                                          rhs_screenshots)):
            comparison_screenshots = (lhs_screenshot, rhs_screenshot)
            log_page_idx = page_idx if multi_page else None
            if not fuzzy or fuzzy == ((0, 0), (0, 0)):
                equal = lhs_hash == rhs_hash
                # sometimes images can have different hashes, but pixels can be identical.
                if not equal:
                    self.logger.info("Image hashes didn't match%s, checking pixel differences" %
                                     ("" if log_page_idx is None else " on page %i" % (page_idx + 1)))
                    max_per_channel, pixels_different = self.get_differences(comparison_screenshots,
                                                                             urls,
                                                                             log_page_idx)
                    equal = pixels_different == 0 and max_per_channel == 0
            else:
                max_per_channel, pixels_different = self.get_differences(comparison_screenshots,
                                                                         urls,
                                                                         log_page_idx)
                allowed_per_channel, allowed_different = fuzzy
                self.logger.info("Allowed %s pixels different, maximum difference per channel %s" %
                                 ("-".join(str(item) for item in allowed_different),
                                  "-".join(str(item) for item in allowed_per_channel)))
                equal = ((pixels_different == 0 and allowed_different[0] == 0) or
                         (max_per_channel == 0 and allowed_per_channel[0] == 0) or
                         (allowed_per_channel[0] <= max_per_channel <= allowed_per_channel[1] and
                          allowed_different[0] <= pixels_different <= allowed_different[1]))
            if not equal:
                return relation != "==", page_idx
        # All screenshots were equal within the fuzziness
        return relation == "==", None

    def get_differences(self, screenshots, urls, page_idx=None):
        """Return (max per-channel difference, number of differing pixels).

        :param screenshots: pair of base64-encoded images to compare.
        :param urls: pair of urls, used when noting solid-colour screenshots.
        :param page_idx: zero-based page index for the log message, or None."""
        from PIL import Image, ImageChops, ImageStat

        lhs = Image.open(io.BytesIO(base64.b64decode(screenshots[0]))).convert("RGB")
        rhs = Image.open(io.BytesIO(base64.b64decode(screenshots[1]))).convert("RGB")
        self.check_if_solid_color(lhs, urls[0])
        self.check_if_solid_color(rhs, urls[1])
        diff = ImageChops.difference(lhs, rhs)
        minimal_diff = diff.crop(diff.getbbox())
        mask = minimal_diff.convert("L", dither=None)
        stat = ImageStat.Stat(minimal_diff, mask)
        per_channel = max(item[1] for item in stat.extrema)
        count = stat.count[0]
        self.logger.info("Found %s pixels different, maximum difference per channel %s%s" %
                         (count,
                          per_channel,
                          "" if page_idx is None else " on page %i" % (page_idx + 1)))
        return per_channel, count

    def check_if_solid_color(self, image, url):
        """Append a note to ``self.message`` if ``image`` is a single colour."""
        extrema = image.getextrema()
        if all(low == high for low, high in extrema):
            color = ''.join('%02X' % value for value, _ in extrema)
            self.message.append("Screenshot is solid color 0x%s for %s\n" % (color, url))

    def run_test(self, test):
        """Compare ``test`` against its references and return a result dict.

        Returns ``{"status": "PASS", "message": None}`` once a chain of
        passing comparisons reaches a reference with no references of its own;
        ``{"status", "message"}`` from the executor if a screenshot fails; or a
        FAIL dict whose ``extra["reftest_screenshots"]`` holds the last
        compared pair."""
        viewport_size = test.viewport_size
        dpi = test.dpi
        page_ranges = test.page_ranges
        self.message = []

        # Depth-first search of reference tree, with the goal
        # of reaching a leaf node with only pass results

        stack = list(((test, item[0]), item[1]) for item in reversed(test.references))
        page_idx = None
        while stack:
            hashes = [None, None]
            screenshots = [None, None]
            urls = [None, None]

            nodes, relation = stack.pop()
            fuzzy = self.get_fuzzy(test, nodes, relation)

            for i, node in enumerate(nodes):
                success, data = self.get_hash(node, viewport_size, dpi, page_ranges)
                if success is False:
                    return {"status": data[0], "message": data[1]}

                hashes[i], screenshots[i] = data
                urls[i] = node.url

            is_pass, page_idx = self.check_pass(hashes, screenshots, urls, relation, fuzzy)
            if is_pass:
                if nodes[1].references:
                    stack.extend(((nodes[1], item[0]), item[1])
                                 for item in reversed(nodes[1].references))
                else:
                    # We passed
                    return {"status": "PASS", "message": None}

        # We failed, so construct a failure message

        if page_idx is None:
            # default to outputting the last page
            page_idx = -1
        for i, (node, screenshot) in enumerate(zip(nodes, screenshots)):
            if screenshot is None:
                success, screenshot = self.retake_screenshot(node, viewport_size, dpi, page_ranges)
                if success:
                    screenshots[i] = screenshot

        log_data = [
            {"url": nodes[0].url,
             "screenshot": screenshots[0][page_idx],
             "hash": hashes[0][page_idx]},
            relation,
            {"url": nodes[1].url,
             "screenshot": screenshots[1][page_idx],
             "hash": hashes[1][page_idx]},
        ]

        return {"status": "FAIL",
                "message": "\n".join(self.message),
                "extra": {"reftest_screenshots": log_data}}

    def get_fuzzy(self, root_test, test_nodes, relation):
        """Return the fuzzy tolerance for comparing ``test_nodes``, or None.

        ``root_test.fuzzy_override`` is consulted before
        ``test_nodes[0].fuzzy``. In each, the keys tried are
        (lhs url, rhs url, relation), then the rhs url, then None. The first
        source that yields a truthy value wins."""
        full_key = tuple([item.url for item in test_nodes] + [relation])
        ref_only_key = test_nodes[1].url

        fuzzy_override = root_test.fuzzy_override
        fuzzy = test_nodes[0].fuzzy

        sources = [fuzzy_override, fuzzy]
        keys = [full_key, ref_only_key, None]
        value = None
        for source in sources:
            for key in keys:
                if key in source:
                    value = source[key]
                    break
            if value:
                break
        return value

    def retake_screenshot(self, node, viewport_size, dpi, page_ranges):
        """Take ``node``'s screenshots again and store them in its cache entry.

        The cache entry for ``node`` must already exist; its hashes are kept."""
        success, data = self.get_screenshot_list(node,
                                                 viewport_size,
                                                 dpi,
                                                 page_ranges)
        if not success:
            return False, data

        key = (node.url, viewport_size, dpi)
        hash_val, _ = self.screenshot_cache[key]
        self.screenshot_cache[key] = hash_val, data
        return True, data

    def get_screenshot_list(self, node, viewport_size, dpi, page_ranges):
        """Take screenshots via the executor, wrapping a single one in a list."""
        success, data = self.executor.screenshot(node, viewport_size, dpi, page_ranges)
        if success and not isinstance(data, list):
            return success, [data]
        return success, data
595
596
class WdspecExecutor(TestExecutor):
    """Executor that runs wdspec tests through ``pytestrunner`` on a worker
    thread (via WdspecRun)."""
    convert_result = pytest_result_converter
    protocol_cls = WdspecProtocol  # type: ClassVar[Type[Protocol]]

    def __init__(self, logger, browser, server_config, webdriver_binary,
                 webdriver_args, timeout_multiplier=1, capabilities=None,
                 debug_info=None, **kwargs):
        super().__init__(logger, browser, server_config,
                         timeout_multiplier=timeout_multiplier,
                         debug_info=debug_info)
        self.timeout_multiplier = timeout_multiplier
        self.capabilities = capabilities
        self.webdriver_binary = webdriver_binary
        self.webdriver_args = webdriver_args

    def setup(self, runner):
        """Create the protocol object, then run the base-class setup."""
        self.protocol = self.protocol_cls(self, self.browser)
        super().setup(runner)

    def is_alive(self):
        return self.protocol.is_alive()

    def on_environment_change(self, new_environment):
        pass

    def do_test(self, test):
        """Run ``test`` with a timeout of the scaled test timeout plus the
        extra timeout; returns a ``(Result, [SubtestResult])`` tuple."""
        timeout = test.timeout * self.timeout_multiplier + self.extra_timeout
        wdspec_run = WdspecRun(self.do_wdspec, test.abs_path, timeout)
        success, data = wdspec_run.run()
        if not success:
            # data is (status, extra information)
            return (test.result_cls(*data), [])
        return self.convert_result(test, data)

    def do_wdspec(self, path, timeout):
        """Invoke pytestrunner for ``path`` with this executor's session config."""
        webdriver_config = {"binary": self.webdriver_binary,
                            "args": self.webdriver_args}
        session_config = {"host": self.browser.host,
                          "port": self.browser.port,
                          "capabilities": self.capabilities,
                          "webdriver": webdriver_config}
        return pytestrunner.run(path, self.server_config, session_config,
                                timeout=timeout)
647
648
class WdspecRun(object):
    """Run ``func(path, timeout)`` on a worker thread, bounded by ``timeout``."""

    def __init__(self, func, path, timeout):
        self.func = func
        self.result = (None, None)
        self.path = path
        self.timeout = timeout
        self.result_flag = threading.Event()

    def run(self):
        """Runs function in a thread and interrupts it if it exceeds the
        given timeout.  Returns (True, (Result, [SubtestResult ...])) in
        case of success, or (False, (status, extra information)) in the
        event of failure.
        """

        executor = threading.Thread(target=self._run)
        executor.start()

        self.result_flag.wait(self.timeout)
        if self.result[1] is None:
            self.result = False, ("EXTERNAL-TIMEOUT", None)

        return self.result

    def _run(self):
        """Thread body: call ``func`` and record its outcome in ``self.result``.

        Socket timeouts and IOErrors are reported as CRASH and other
        exceptions as INTERNAL-ERROR with the traceback. ``result_flag`` is
        always set."""
        try:
            self.result = True, self.func(self.path, self.timeout)
        except (socket.timeout, IOError):
            self.result = False, ("CRASH", None)
        except Exception as e:
            # Most Python 3 exceptions have no ``message`` attribute, so a
            # default is required; otherwise this handler itself raises and no
            # result is recorded.
            message = str(getattr(e, "message", ""))
            if message:
                message += "\n"
            message += traceback.format_exc()
            self.result = False, ("INTERNAL-ERROR", message)
        finally:
            self.result_flag.set()
686
687
class CallbackHandler(object):
    """Handle callbacks from testdriver-using tests.

    The default implementation here makes sense for things that are roughly like
    WebDriver. Things that are more different to WebDriver may need to create a
    fully custom implementation."""

    # Exceptions an action may raise to signal that it is not implemented.
    unimplemented_exc = (NotImplementedError,)  # type: ClassVar[Tuple[Type[Exception], ...]]

    def __init__(self, logger, protocol, test_window):
        self.logger = logger
        self.protocol = protocol
        self.test_window = test_window
        # Callback command name -> handler method.
        self.callbacks = {"action": self.process_action,
                          "complete": self.process_complete}

        # One instance of every class in ``actions``, keyed by its ``name``.
        handlers = {}
        for action_cls in actions:
            handlers[action_cls.name] = action_cls(self.logger, self.protocol)
        self.actions = handlers

    def __call__(self, result):
        """Dispatch a ``(url, command, payload)`` message to its handler.

        :raises ValueError: if ``command`` has no registered handler."""
        url, command, payload = result
        self.logger.debug("Got async callback: %s" % result[1])
        try:
            handler = self.callbacks[command]
        except KeyError:
            raise ValueError("Unknown callback type %r" % result[1])
        return handler(url, payload)

    def process_complete(self, url, payload):
        """Return ``(True, [url path] + payload)``."""
        return True, [strip_server(url)] + payload

    def process_action(self, url, payload):
        """Run the action named in ``payload`` and report its outcome.

        The outcome is sent back via ``_send_message`` under ``payload["id"]``.
        An unimplemented action is reported as an error. Any other exception
        is reported and then re-raised. Otherwise returns ``(False, None)``.

        :raises ValueError: if the action name is not registered."""
        action = payload["action"]
        cmd_id = payload["id"]
        self.logger.debug("Got action: %s" % action)
        try:
            handler = self.actions[action]
        except KeyError:
            raise ValueError("Unknown action %s" % action)

        try:
            with ActionContext(self.logger, self.protocol, payload.get("context")):
                outcome = handler(payload)
        except self.unimplemented_exc:
            not_implemented = "Action %s not implemented" % action
            self.logger.warning(not_implemented)
            self._send_message(cmd_id, "complete", "error", not_implemented)
        except Exception:
            self.logger.warning("Action %s failed" % action)
            self.logger.warning(traceback.format_exc())
            self._send_message(cmd_id, "complete", "error")
            raise
        else:
            self.logger.debug("Action %s completed with result %s" % (action, outcome))
            self._send_message(cmd_id, "complete", "success",
                               json.dumps({"result": outcome}))

        return False, None

    def _send_message(self, cmd_id, message_type, status, message=None):
        """Forward a response to ``self.protocol.testdriver.send_message``."""
        self.protocol.testdriver.send_message(cmd_id, message_type, status, message=message)
749
750
class ActionContext(object):
    """Context manager that runs a testdriver action in another window.

    If ``context`` is None it does nothing. Otherwise, entering records
    ``protocol.base.current_window`` and switches the testdriver to
    ``context``. Exiting switches back to the recorded window. Exceptions are
    not suppressed.
    """

    def __init__(self, logger, protocol, context):
        self.logger = logger
        self.protocol = protocol
        self.context = context
        self.initial_window = None

    def __enter__(self):
        if self.context is not None:
            self.initial_window = self.protocol.base.current_window
            self.logger.debug("Switching to window %s" % self.context)
            self.protocol.testdriver.switch_to_window(self.context, self.initial_window)

    def __exit__(self, *args):
        if self.context is not None:
            self.logger.debug("Switching back to initial window")
            self.protocol.base.set_window(self.initial_window)
            self.initial_window = None
773