1import base64
2import hashlib
3import io
4import json
5import os
6import threading
7import traceback
8import socket
9import sys
10from abc import ABCMeta, abstractmethod
11from six import text_type
12from six.moves.http_client import HTTPConnection
13from six.moves.urllib.parse import urljoin, urlsplit, urlunsplit
14
15from ..testrunner import Stop
16from .actions import actions
17from .protocol import Protocol, BaseProtocolPart
18
# Directory that contains this module file.
here = os.path.dirname(__file__)
20
21
def executor_kwargs(test_type, server_config, cache_manager, run_info_data,
                    **kwargs):
    """Assemble the keyword arguments used to construct an executor.

    :param test_type: Name of the test type; "reftest" and "wdspec" add
                      extra entries to the returned dictionary.
    :param server_config: Stored under the "server_config" key unchanged.
    :param cache_manager: Object whose ``dict()`` method supplies the
                          screenshot cache for reftests.
    :param run_info_data: Accepted but not read by this function.
    :param kwargs: Must contain "timeout_multiplier" and "debug_info";
                   "binary", "webdriver_binary" and "webdriver_args" are
                   looked up (defaulting to None) for wdspec tests.
    :returns: Dictionary of executor constructor arguments.
    """
    multiplier = kwargs["timeout_multiplier"]

    rv = {
        "server_config": server_config,
        # A missing multiplier means "use the base timeout as-is".
        "timeout_multiplier": 1 if multiplier is None else multiplier,
        "debug_info": kwargs["debug_info"],
    }

    if test_type == "reftest":
        rv["screenshot_cache"] = cache_manager.dict()
    elif test_type == "wdspec":
        for name in ("binary", "webdriver_binary", "webdriver_args"):
            rv[name] = kwargs.get(name)

    return rv
41
42
def strip_server(url):
    """Drop the scheme and netloc of a url, keeping the path together with
    any query or fragment.

    url - the url to strip

    e.g. http://example.org:8000/tests?id=1#2 becomes /tests?id=1#2"""

    split = urlsplit(url)
    # Rebuild the url with empty scheme and netloc components.
    return urlunsplit(("", "", split.path, split.query, split.fragment))
55
56
class TestharnessResultConverter(object):
    """Callable that turns raw testharness results into result objects."""

    # Numeric harness status code -> harness status name.
    harness_codes = {
        0: "OK",
        1: "ERROR",
        2: "TIMEOUT",
        3: "PRECONDITION_FAILED",
    }

    # Numeric subtest status code -> subtest status name.
    test_codes = {
        0: "PASS",
        1: "FAIL",
        2: "TIMEOUT",
        3: "NOTRUN",
        4: "PRECONDITION_FAILED",
    }

    def __call__(self, test, result, extra=None):
        """Convert a JSON result into a (TestResult, [SubtestResult]) tuple

        :param test: Object providing ``url``, ``result_cls`` and
                     ``subtest_result_cls``.
        :param result: Five-item sequence of (url, harness status code,
                       message, stack, subtest results), where each subtest
                       result is (name, status code, message, stack).
        :param extra: Passed through to ``test.result_cls`` as ``extra``.
        """
        result_url, status, message, stack, subtest_results = result
        assert result_url == test.url, ("Got results from %s, expected %s" %
                                        (result_url, test.url))

        harness_result = test.result_cls(self.harness_codes[status],
                                         message,
                                         extra=extra,
                                         stack=stack)

        subtests = []
        for name, code, sub_message, sub_stack in subtest_results:
            subtests.append(test.subtest_result_cls(name,
                                                    self.test_codes[code],
                                                    sub_message,
                                                    sub_stack))
        return (harness_result, subtests)


# Shared converter instance used as ``convert_result`` by testharness executors.
testharness_result_converter = TestharnessResultConverter()
81
82
def hash_screenshot(data):
    """Return the hex SHA-1 digest of the bytes encoded by the base64 string *data*."""
    raw = base64.b64decode(data)
    digest = hashlib.sha1()
    digest.update(raw)
    return digest.hexdigest()
86
87
def _ensure_hash_in_reftest_screenshots(extra):
    """Fill in a "hash" for every reftest screenshot entry that lacks one.

    Marionette internal reftest runner does not produce hashes.

    :param extra: Mapping that may hold a "reftest_screenshots" list; the
                  dictionaries in that list are updated in place.
    """
    for entry in extra.get("reftest_screenshots") or ():
        # Non-dict entries are the relation strings sitting between the
        # screenshot dictionaries; leave them alone.
        if type(entry) is dict and "hash" not in entry:
            entry["hash"] = hash_screenshot(entry["screenshot"])
102
103
def reftest_result_converter(self, test, result):
    """Convert a reftest result dictionary into a (result, []) tuple.

    :param test: Object providing ``result_cls``.
    :param result: Mapping with "status" and "message" keys and optional
                   "extra" and "stack" keys.
    :returns: Tuple of the harness result and an empty subtest list.
    """
    extra = result.get("extra", {})
    # Screenshots logged without a hash get one computed here.
    _ensure_hash_in_reftest_screenshots(extra)
    harness_result = test.result_cls(result["status"],
                                     result["message"],
                                     extra=extra,
                                     stack=result.get("stack"))
    return (harness_result, [])
112
113
def pytest_result_converter(self, test, data):
    """Convert (harness_data, subtest_data) into result objects.

    :param test: Object providing ``result_cls`` and ``subtest_result_cls``.
    :param data: Pair of the positional arguments for ``result_cls`` and a
                 list of positional-argument sequences for
                 ``subtest_result_cls`` (or None for no subtests).
    :returns: Tuple of the harness result and the list of subtest results.
    """
    harness_data, subtest_data = data
    subtest_items = [] if subtest_data is None else subtest_data

    harness_result = test.result_cls(*harness_data)

    subtest_results = []
    for item in subtest_items:
        subtest_results.append(test.subtest_result_cls(*item))

    return (harness_result, subtest_results)
124
125
def crashtest_result_converter(self, test, result):
    """Build a (result, []) tuple by passing *result* to ``test.result_cls`` as keyword arguments."""
    harness_result = test.result_cls(**result)
    return harness_result, []
128
class ExecutorException(Exception):
    """Exception carrying a result ``status`` together with a ``message``."""

    def __init__(self, status, message):
        # Both values are stored as plain attributes on the exception.
        self.message = message
        self.status = status
133
134
class TimedRunner(object):
    """Run a function on a worker thread and wait a bounded time for it.

    Subclasses implement :meth:`set_timeout` and :meth:`run_func`. ``run``
    reads ``self.result`` as a ``(success, data)`` tuple and waits on
    ``self.result_flag``, so ``run_func`` is expected to store its outcome
    in ``self.result`` and set the flag when it is done.

    :param logger: Logger used to report an unresponsive browser.
    :param func: Stored as ``self.func`` for use by subclasses.
    :param protocol: Object whose ``is_alive()`` tells whether the browser
                     still responds.
    :param url: Stored as ``self.url`` for use by subclasses.
    :param timeout: Test timeout in seconds; a falsy value means wait
                    without limit.
    :param extra_timeout: Additional leeway in seconds (added twice when
                          waiting, see :meth:`run`).
    """

    def __init__(self, logger, func, protocol, url, timeout, extra_timeout):
        self.func = func
        self.logger = logger
        self.result = None
        self.protocol = protocol
        self.url = url
        self.timeout = timeout
        self.extra_timeout = extra_timeout
        self.result_flag = threading.Event()

    def run(self):
        """Start ``run_func`` on a thread and return its result.

        :returns: ``Stop`` if ``set_timeout`` or ``before_run`` returned
                  ``Stop``; otherwise ``self.result``. When no usable result
                  was produced this is ``(False, (status, message))`` with
                  status "INTERNAL-ERROR", "EXTERNAL-TIMEOUT" or "CRASH".
        """
        if self.set_timeout() is Stop:
            return Stop

        if self.before_run() is Stop:
            return Stop

        executor = threading.Thread(target=self.run_func)
        executor.start()

        # Add twice the extra timeout since the called function is expected to
        # wait at least self.timeout + self.extra_timeout and this gives some leeway
        timeout = self.timeout + 2 * self.extra_timeout if self.timeout else None
        finished = self.result_flag.wait(timeout)
        if self.result is None:
            if finished:
                # flag is True unless we timeout; this *shouldn't* happen, but
                # it can if self.run_func fails to set self.result due to raising
                self.result = False, ("INTERNAL-ERROR", "%s.run_func didn't set a result" %
                                      self.__class__.__name__)
            else:
                if self.protocol.is_alive():
                    message = "Executor hit external timeout (this may indicate a hang)\n"
                    # Get a traceback for the current stack of the executor
                    # thread. The thread may already have exited (e.g.
                    # run_func raised before setting the flag), in which case
                    # it has no frame; look it up without raising KeyError.
                    frame = sys._current_frames().get(executor.ident)
                    if frame is not None:
                        message += "".join(traceback.format_stack(frame))
                    else:
                        message += "Executor thread exited without setting a result\n"
                    self.result = False, ("EXTERNAL-TIMEOUT", message)
                else:
                    self.logger.info("Browser not responding, setting status to CRASH")
                    self.result = False, ("CRASH", None)
        elif self.result[1] is None:
            # We didn't get any data back from the test, so check if the
            # browser is still responsive
            if self.protocol.is_alive():
                self.result = False, ("INTERNAL-ERROR", None)
            else:
                self.logger.info("Browser not responding, setting status to CRASH")
                self.result = False, ("CRASH", None)

        return self.result

    def set_timeout(self):
        """Hook called first by :meth:`run`; returning ``Stop`` aborts the run."""
        raise NotImplementedError

    def before_run(self):
        """Optional hook called before the thread starts; returning ``Stop`` aborts the run."""
        pass

    def run_func(self):
        """Body executed on the worker thread; must be provided by subclasses."""
        raise NotImplementedError
194
195
class TestExecutor(object):
    """Abstract base class for the objects that actually execute tests in a
    specific browser. There is typically one TestExecutor subclass per
    combination of test type and method of executing tests.

    :param browser: ExecutorBrowser instance providing properties of the
                    browser that will be tested.
    :param server_config: Dictionary of wptserve server configuration of the
                          form stored in TestEnvironment.config
    :param timeout_multiplier: Multiplier relative to base timeout to use
                               when setting test timeout.
    """
    __metaclass__ = ABCMeta

    test_type = None
    convert_result = None
    supports_testdriver = False
    supports_jsshell = False
    # Seconds to allow after the internal test timeout before the harness
    # itself should force a timeout.
    extra_timeout = 5

    def __init__(self, logger, browser, server_config, timeout_multiplier=1,
                 debug_info=None, **kwargs):
        self.logger = logger
        self.browser = browser
        self.server_config = server_config
        self.timeout_multiplier = timeout_multiplier
        self.debug_info = debug_info
        self.runner = None
        self.protocol = None  # This must be set in subclasses
        # Environment of the most recently completed test.
        self.last_environment = {"protocol": "http",
                                 "prefs": {}}

    def setup(self, runner):
        """Perform the steps needed before tests can start, e.g. connecting
        to the browser instance.

        :param runner: TestRunner instance that is going to run the tests"""
        self.runner = runner
        protocol = self.protocol
        if protocol is None:
            return
        protocol.setup(runner)

    def teardown(self):
        """Perform cleanup once tests have finished"""
        protocol = self.protocol
        if protocol is None:
            return
        protocol.teardown()

    def reset(self):
        """Re-initialize internal state so tests can be executed repeatedly,
        as implemented by the `--rerun` command-line argument."""
        pass

    def run_test(self, test):
        """Run one test and send its result to the runner.

        Any exception raised by ``do_test`` is logged and converted to a
        result via ``result_from_exception``. If the result is ``Stop`` it
        is returned immediately and nothing is sent.

        :param test: The test to run"""
        if test.environment != self.last_environment:
            self.on_environment_change(test.environment)

        try:
            result = self.do_test(test)
        except Exception as e:
            formatted = traceback.format_exc()
            self.logger.warning(formatted)
            result = self.result_from_exception(test, e, formatted)

        if result is Stop:
            return result

        # log result of parent test
        harness_result = result[0]
        if harness_result.status == "ERROR":
            self.logger.debug(harness_result.message)

        self.last_environment = test.environment

        self.runner.send_message("test_ended", test, result)

    def server_url(self, protocol):
        """Return the ``scheme://host:port`` base url for *protocol*.

        "h2" is served over the "https" scheme; the port is the first one
        configured for the protocol."""
        if protocol == "h2":
            scheme = "https"
        else:
            scheme = protocol
        host = self.server_config["browser_host"]
        port = self.server_config["ports"][protocol][0]
        return "%s://%s:%s" % (scheme, host, port)

    def test_url(self, test):
        """Return the absolute url of *test* on the server for its protocol."""
        base = self.server_url(test.environment["protocol"])
        return urljoin(base, test.url)

    @abstractmethod
    def do_test(self, test):
        """Test-type and protocol specific implementation of running a
        specific test.

        :param test: The test to run."""
        pass

    def on_environment_change(self, new_environment):
        """Hook invoked by run_test when a test's environment differs from
        the previous test's; does nothing by default."""
        pass

    def result_from_exception(self, test, e, exception_string):
        """Build a (result, []) tuple describing exception *e*.

        The exception's ``status`` is used when it is one of
        ``test.result_cls.statuses``; otherwise "INTERNAL-ERROR". The
        message is the exception's ``message`` attribute (if non-empty)
        followed by *exception_string*."""
        status = "INTERNAL-ERROR"
        if hasattr(e, "status") and e.status in test.result_cls.statuses:
            status = e.status

        message = text_type(getattr(e, "message", ""))
        if message:
            message += "\n"
        message += exception_string

        return test.result_cls(status, message), []

    def wait(self):
        """Delegate to the ``wait`` method of the protocol's ``base`` part."""
        self.protocol.base.wait()
307
308
class TestharnessExecutor(TestExecutor):
    """TestExecutor whose ``convert_result`` is ``testharness_result_converter``."""

    convert_result = testharness_result_converter
311
312
class RefTestExecutor(TestExecutor):
    """TestExecutor for reftests that additionally keeps a screenshot cache.

    :param screenshot_cache: Stored as ``self.screenshot_cache``; the other
                             parameters are forwarded to TestExecutor.
    """

    convert_result = reftest_result_converter

    def __init__(self, logger, browser, server_config, timeout_multiplier=1, screenshot_cache=None,
                 debug_info=None, **kwargs):
        # Extra keyword arguments are accepted but not forwarded to the base
        # class.
        base_options = {"timeout_multiplier": timeout_multiplier,
                        "debug_info": debug_info}
        TestExecutor.__init__(self, logger, browser, server_config, **base_options)

        self.screenshot_cache = screenshot_cache
323
324
class CrashtestExecutor(TestExecutor):
    """TestExecutor whose ``convert_result`` is ``crashtest_result_converter``."""

    convert_result = crashtest_result_converter
327
328
class RefTestImplementation(object):
    """Decide whether a reftest passes by comparing screenshots of the test
    with the screenshots of its (possibly nested) references.

    :param executor: Object providing ``timeout_multiplier``,
                     ``screenshot_cache``, ``logger`` and a
                     ``screenshot(test, viewport_size, dpi)`` method that
                     returns ``(success, data)``.
    """

    def __init__(self, executor):
        self.timeout_multiplier = executor.timeout_multiplier
        self.executor = executor
        # Cache of url:(screenshot hash, screenshot). Typically the
        # screenshot is None, but we set this value if a test fails
        # and the screenshot was taken from the cache so that we may
        # retrieve the screenshot from the cache directly in the future
        self.screenshot_cache = self.executor.screenshot_cache
        # List of message lines for the test currently being run; reset at
        # the start of every run_test call.
        self.message = None

    def setup(self):
        """No setup is required by this implementation."""
        pass

    def teardown(self):
        """No teardown is required by this implementation."""
        pass

    @property
    def logger(self):
        """The executor's logger."""
        return self.executor.logger

    def get_hash(self, test, viewport_size, dpi):
        """Return the hash and screenshot for *test*, taking a screenshot
        only when the cache has no entry for (url, viewport_size, dpi).

        :returns: ``(True, (hash, screenshot))`` on success, or
                  ``(False, data)`` with whatever the executor's
                  ``screenshot`` call reported on failure.
        """
        key = (test.url, viewport_size, dpi)

        if key not in self.screenshot_cache:
            success, data = self.executor.screenshot(test, viewport_size, dpi)

            if not success:
                return False, data

            screenshot = data
            hash_value = hash_screenshot(data)
            self.screenshot_cache[key] = (hash_value, screenshot)

            rv = (hash_value, screenshot)
        else:
            rv = self.screenshot_cache[key]

        self.message.append("%s %s" % (test.url, rv[0]))
        return True, rv

    def reset(self):
        """Empty the screenshot cache."""
        self.screenshot_cache.clear()

    def is_pass(self, hashes, screenshots, urls, relation, fuzzy):
        """Return whether the two screenshots satisfy *relation*.

        :param hashes: Pair of screenshot hashes.
        :param screenshots: Pair of base64-encoded screenshots; only decoded
                            when a pixel comparison is needed.
        :param urls: Pair of urls, used in log messages.
        :param relation: "==" or "!=".
        :param fuzzy: None/empty for an exact comparison, otherwise
                      ``((min_per_channel, max_per_channel),
                      (min_pixels, max_pixels))`` bounds on the allowed
                      difference.
        """
        assert relation in ("==", "!=")
        if not fuzzy or fuzzy == ((0,0), (0,0)):
            equal = hashes[0] == hashes[1]
            # sometimes images can have different hashes, but pixels can be identical.
            if not equal:
                self.logger.info("Image hashes didn't match, checking pixel differences")
                max_per_channel, pixels_different = self.get_differences(screenshots, urls)
                equal = pixels_different == 0 and max_per_channel == 0
        else:
            max_per_channel, pixels_different = self.get_differences(screenshots, urls)
            allowed_per_channel, allowed_different = fuzzy
            self.logger.info("Allowed %s pixels different, maximum difference per channel %s" %
                             ("-".join(str(item) for item in allowed_different),
                              "-".join(str(item) for item in allowed_per_channel)))
            equal = ((pixels_different == 0 and allowed_different[0] == 0) or
                     (max_per_channel == 0 and allowed_per_channel[0] == 0) or
                     (allowed_per_channel[0] <= max_per_channel <= allowed_per_channel[1] and
                      allowed_different[0] <= pixels_different <= allowed_different[1]))
        return equal if relation == "==" else not equal

    def get_differences(self, screenshots, urls):
        """Return ``(max difference per channel, number of differing
        pixels)`` for a pair of base64-encoded screenshots.

        PIL is imported lazily so it is only required when a pixel
        comparison is actually performed.
        """
        from PIL import Image, ImageChops, ImageStat

        lhs = Image.open(io.BytesIO(base64.b64decode(screenshots[0]))).convert("RGB")
        rhs = Image.open(io.BytesIO(base64.b64decode(screenshots[1]))).convert("RGB")
        self.check_if_solid_color(lhs, urls[0])
        self.check_if_solid_color(rhs, urls[1])
        diff = ImageChops.difference(lhs, rhs)
        # Restrict the statistics to the bounding box of the differing area,
        # masked by the greyscale difference.
        minimal_diff = diff.crop(diff.getbbox())
        mask = minimal_diff.convert("L", dither=None)
        stat = ImageStat.Stat(minimal_diff, mask)
        per_channel = max(item[1] for item in stat.extrema)
        count = stat.count[0]
        self.logger.info("Found %s pixels different, maximum difference per channel %s" %
                         (count, per_channel))
        return per_channel, count

    def check_if_solid_color(self, image, url):
        """Append a note to the message when *image* is a single solid colour
        (every band's minimum equals its maximum)."""
        extrema = image.getextrema()
        if all(low == high for low, high in extrema):
            color = ''.join('%02X' % value for value, _ in extrema)
            self.message.append("Screenshot is solid color 0x%s for %s\n" % (color, url))

    def run_test(self, test):
        """Run the comparison for *test* and return a result dictionary.

        :returns: ``{"status": "PASS", "message": None}`` when a leaf of the
                  reference tree is reached with every comparison passing; a
                  "FAIL" dictionary with the joined message and the
                  screenshots of the last failed comparison under
                  ``extra["reftest_screenshots"]`` otherwise; the status and
                  message reported by a failed screenshot attempt; or an
                  "INTERNAL-ERROR" result when the test has no references.
        """
        viewport_size = test.viewport_size
        dpi = test.dpi
        self.message = []

        # Depth-first search of reference tree, with the goal
        # of reaching a leaf node with only pass results

        stack = [((test, item[0]), item[1]) for item in reversed(test.references)]
        if not stack:
            # Without any reference there is nothing to compare and the
            # failure report below would have no nodes to describe.
            return {"status": "INTERNAL-ERROR",
                    "message": "Reftest %s has no references to compare against" % test.url}

        while stack:
            hashes = [None, None]
            screenshots = [None, None]
            urls = [None, None]

            nodes, relation = stack.pop()
            fuzzy = self.get_fuzzy(test, nodes, relation)

            for i, node in enumerate(nodes):
                success, data = self.get_hash(node, viewport_size, dpi)
                if success is False:
                    return {"status": data[0], "message": data[1]}

                hashes[i], screenshots[i] = data
                urls[i] = node.url

            if self.is_pass(hashes, screenshots, urls, relation, fuzzy):
                if nodes[1].references:
                    stack.extend(list(((nodes[1], item[0]), item[1]) for item in reversed(nodes[1].references)))
                else:
                    # We passed
                    return {"status":"PASS", "message": None}

        # We failed, so construct a failure message. nodes, hashes,
        # screenshots and relation describe the last comparison attempted.

        for i, (node, screenshot) in enumerate(zip(nodes, screenshots)):
            if screenshot is None:
                # Cached entries may hold only the hash; take the screenshot
                # again so it can be included in the log.
                success, screenshot = self.retake_screenshot(node, viewport_size, dpi)
                if success:
                    screenshots[i] = screenshot

        log_data = [
            {"url": nodes[0].url, "screenshot": screenshots[0], "hash": hashes[0]},
            relation,
            {"url": nodes[1].url, "screenshot": screenshots[1], "hash": hashes[1]},
        ]

        return {"status": "FAIL",
                "message": "\n".join(self.message),
                "extra": {"reftest_screenshots": log_data}}

    def get_fuzzy(self, root_test, test_nodes, relation):
        """Look up the fuzzy setting for a comparison.

        ``root_test.fuzzy_override`` is consulted before
        ``test_nodes[0].fuzzy``; within each, the key
        ``(test url, reference url, relation)`` is tried first, then the
        reference url alone, then None.

        :returns: The first truthy value found, else the last value looked
                  up (None when no key matched).
        """
        full_key = tuple([item.url for item in test_nodes] + [relation])
        ref_only_key = test_nodes[1].url

        fuzzy_override = root_test.fuzzy_override
        fuzzy = test_nodes[0].fuzzy

        sources = [fuzzy_override, fuzzy]
        keys = [full_key, ref_only_key, None]
        value = None
        for source in sources:
            for key in keys:
                if key in source:
                    value = source[key]
                    break
            if value:
                break
        return value

    def retake_screenshot(self, node, viewport_size, dpi):
        """Take a fresh screenshot of *node* and store it next to the hash
        already cached for it.

        :returns: ``(True, screenshot)`` on success, else ``(False, data)``
                  from the executor's ``screenshot`` call.
        """
        success, data = self.executor.screenshot(node, viewport_size, dpi)
        if not success:
            return False, data

        key = (node.url, viewport_size, dpi)
        hash_val, _ = self.screenshot_cache[key]
        self.screenshot_cache[key] = hash_val, data
        return True, data
496
497
class WdspecExecutor(TestExecutor):
    """TestExecutor that runs wdspec tests through ``pytestrunner``.

    ``protocol_cls`` must be provided (it is None here); it is instantiated
    with the executor and the browser at the end of ``__init__``.
    """

    convert_result = pytest_result_converter
    protocol_cls = None

    def __init__(self, logger, browser, server_config, webdriver_binary,
                 webdriver_args, timeout_multiplier=1, capabilities=None,
                 debug_info=None, **kwargs):
        self.do_delayed_imports()

        base_options = {"timeout_multiplier": timeout_multiplier,
                        "debug_info": debug_info}
        TestExecutor.__init__(self, logger, browser, server_config, **base_options)

        self.webdriver_binary = webdriver_binary
        self.webdriver_args = webdriver_args
        self.timeout_multiplier = timeout_multiplier
        self.capabilities = capabilities
        # Created last so the protocol can read the attributes set above.
        self.protocol = self.protocol_cls(self, browser)

    def is_alive(self):
        """Report whether the protocol considers the connection alive."""
        return self.protocol.is_alive()

    def on_environment_change(self, new_environment):
        """Environment changes need no action for wdspec tests."""
        pass

    def do_test(self, test):
        """Run *test* via WdspecRun and return a (result, subtests) tuple.

        The timeout handed to the run is the test timeout scaled by the
        multiplier, plus ``extra_timeout``."""
        timeout = test.timeout * self.timeout_multiplier + self.extra_timeout

        run = WdspecRun(self.do_wdspec,
                        self.protocol.session_config,
                        test.abs_path,
                        timeout)
        success, data = run.run()

        if not success:
            # On failure data is the argument tuple for the result class.
            return (test.result_cls(*data), [])

        return self.convert_result(test, data)

    def do_wdspec(self, session_config, path, timeout):
        """Invoke ``pytestrunner.run`` for the test file at *path*."""
        return pytestrunner.run(path,
                                self.server_config,
                                session_config,
                                timeout=timeout)

    def do_delayed_imports(self):
        """Import ``pytestrunner`` into the module namespace on first use."""
        global pytestrunner
        from . import pytestrunner
543
544
class WdspecRun(object):
    """Run ``func(session, path, timeout)`` on a thread with a time limit.

    :param func: Callable invoked as ``func(session, path, timeout)``.
    :param session: First argument passed to *func*.
    :param path: Second argument passed to *func*.
    :param timeout: Seconds to wait for *func*; also passed to it.
    """

    def __init__(self, func, session, path, timeout):
        self.func = func
        self.result = (None, None)
        self.session = session
        self.path = path
        self.timeout = timeout
        self.result_flag = threading.Event()

    def run(self):
        """Runs function in a thread and stops waiting for it if it exceeds
        the given timeout.  Returns (True, (Result, [SubtestResult ...])) in
        case of success, or (False, (status, extra information)) in the
        event of failure.
        """

        executor = threading.Thread(target=self._run)
        executor.start()

        self.result_flag.wait(self.timeout)
        if self.result[1] is None:
            # Nothing was stored before the wait ended.
            self.result = False, ("EXTERNAL-TIMEOUT", None)

        return self.result

    def _run(self):
        """Thread body: call the function and record its outcome.

        Socket timeouts and IO errors are recorded as "CRASH"; any other
        exception as "INTERNAL-ERROR" with the exception's ``message``
        attribute (when present and non-empty) followed by the traceback.
        """
        try:
            self.result = True, self.func(self.session, self.path, self.timeout)
        except (socket.timeout, IOError):
            self.result = False, ("CRASH", None)
        except Exception as e:
            # Most exceptions have no ``message`` attribute on Python 3, so
            # a default is required; without it the handler itself raised
            # and the failure was misreported as an external timeout.
            detail = getattr(e, "message", "")
            message = "%s\n" % (detail,) if detail else ""
            message += traceback.format_exc()
            self.result = False, ("INTERNAL-ERROR", message)
        finally:
            self.result_flag.set()
583
584
class ConnectionlessBaseProtocolPart(BaseProtocolPart):
    """BaseProtocolPart in which every operation is a no-op."""

    def load(self, url):
        """Do nothing."""

    def execute_script(self, script, asynchronous=False):
        """Do nothing."""

    def set_timeout(self, timeout):
        """Do nothing."""

    def wait(self):
        """Do nothing."""

    def set_window(self, handle):
        """Do nothing."""
600
601
class ConnectionlessProtocol(Protocol):
    """Protocol built from the no-op base part, with no-op connection hooks."""

    implements = [ConnectionlessBaseProtocolPart]

    def connect(self):
        """Do nothing."""

    def after_connect(self):
        """Do nothing."""
610
611
class WebDriverProtocol(Protocol):
    """Protocol that starts a WebDriver HTTP server and exposes the details
    needed to talk to it.

    ``server_cls`` must be provided by subclasses (it is None here); it is
    instantiated in :meth:`connect` with the logger, the WebDriver binary
    and its arguments.
    """
    server_cls = None

    implements = [ConnectionlessBaseProtocolPart]

    def __init__(self, executor, browser):
        Protocol.__init__(self, executor, browser)
        self.webdriver_binary = executor.webdriver_binary
        self.webdriver_args = executor.webdriver_args
        self.capabilities = self.executor.capabilities
        # Both are populated by connect().
        self.session_config = None
        self.server = None

    def connect(self):
        """Connect to browser via the HTTP server."""
        self.server = self.server_cls(
            self.logger,
            binary=self.webdriver_binary,
            args=self.webdriver_args)
        self.server.start(block=False)
        self.logger.info(
            "WebDriver HTTP server listening at %s" % self.server.url)
        self.session_config = {"host": self.server.host,
                               "port": self.server.port,
                               "capabilities": self.capabilities}

    def after_connect(self):
        """Nothing to do once connected."""
        pass

    def teardown(self):
        """Stop the server if one was started and is still alive."""
        if self.server is not None and self.server.is_alive():
            self.server.stop()

    def is_alive(self):
        """Test that the connection is still alive.

        Because the remote communication happens over HTTP we need to
        make an explicit request to the remote.  It is allowed for
        WebDriver spec tests to not have a WebDriver session, since this
        may be what is tested.

        An HTTP request to an invalid path that results in a 404 is
        proof enough to us that the server is alive and kicking.

        :returns: True if the server answered with a 404; False if it
                  answered with anything else or the request failed with a
                  socket/IO error (e.g. the connection was refused).
        """
        conn = HTTPConnection(self.server.host, self.server.port)
        try:
            conn.request("HEAD", self.server.base_path + "invalid")
            res = conn.getresponse()
            return res.status == 404
        except (socket.error, IOError):
            # A server that cannot be reached is, by definition, not alive.
            return False
        finally:
            # Release the socket; a new connection is made for every check.
            conn.close()
660
661
class CallbackHandler(object):
    """Handle callbacks from testdriver-using tests.

    The default implementation here makes sense for things that are roughly like
    WebDriver. Things that are more different to WebDriver may need to create a
    fully custom implementation."""

    # Exceptions from an action handler that mean "not implemented".
    unimplemented_exc = (NotImplementedError,)

    def __init__(self, logger, protocol, test_window):
        self.protocol = protocol
        self.test_window = test_window
        self.logger = logger

        # Callback command name -> bound handler method.
        self.callbacks = {
            "action": self.process_action,
            "complete": self.process_complete
        }

        # Action name -> handler instance, one per registered action class.
        self.actions = {action_cls.name: action_cls(self.logger, self.protocol)
                        for action_cls in actions}

    def __call__(self, result):
        """Dispatch a ``(url, command, payload)`` callback to its handler.

        :raises ValueError: if the command has no registered handler.
        :returns: Whatever the handler returns.
        """
        url, command, payload = result
        self.logger.debug("Got async callback: %s" % result[1])
        try:
            handler = self.callbacks[command]
        except KeyError:
            raise ValueError("Unknown callback type %r" % result[1])
        else:
            return handler(url, payload)

    def process_complete(self, url, payload):
        """Return ``(True, [server-stripped url] + payload)``."""
        stripped = strip_server(url)
        return True, [stripped] + payload

    def process_action(self, url, payload):
        """Run the action named by ``payload["action"]`` and report the
        outcome back through the testdriver protocol part.

        An unimplemented action is reported as an error; any other failure
        is reported as an error and then re-raised.

        :raises ValueError: if the action name is not known.
        :returns: ``(False, None)``.
        """
        action = payload["action"]
        self.logger.debug("Got action: %s" % action)
        try:
            handler = self.actions[action]
        except KeyError:
            raise ValueError("Unknown action %s" % action)

        try:
            outcome = handler(payload)
        except self.unimplemented_exc:
            self.logger.warning("Action %s not implemented" % action)
            self._send_message("complete", "error", "Action %s not implemented" % action)
        except Exception:
            self.logger.warning("Action %s failed" % action)
            self.logger.warning(traceback.format_exc())
            self._send_message("complete", "error")
            raise
        else:
            self.logger.debug("Action %s completed with result %s" % (action, outcome))
            serialized = json.dumps({"result": outcome})
            self._send_message("complete", "success", serialized)

        return False, None

    def _send_message(self, message_type, status, message=None):
        """Forward a message to the protocol's testdriver part."""
        self.protocol.testdriver.send_message(message_type, status, message=message)
721