import base64
import hashlib
import io
import json
import os
import threading
import traceback
import socket
import sys
from abc import ABCMeta, abstractmethod
from http.client import HTTPConnection, HTTPException
from typing import Any, Callable, ClassVar, Optional, Tuple, Type, TYPE_CHECKING
from urllib.parse import urljoin, urlsplit, urlunsplit

from .actions import actions
from .protocol import Protocol, BaseProtocolPart

if TYPE_CHECKING:
    from ..webdriver_server import WebDriverServer

here = os.path.dirname(__file__)


def executor_kwargs(test_type, test_environment, run_info_data, **kwargs):
    """Return a dict of keyword arguments for constructing an executor.

    :param test_type: Test type string (e.g. "reftest", "wdspec").
    :param test_environment: Object providing ``config`` and ``cache_manager``.
    :param run_info_data: Unused here; accepted for interface compatibility.
    :param kwargs: Command-line derived options; must contain
        ``timeout_multiplier``, ``debug_info``, ``pause_after_test``,
        ``pause_on_unexpected`` and ``debug_test``.
    """
    timeout_multiplier = kwargs["timeout_multiplier"]
    if timeout_multiplier is None:
        timeout_multiplier = 1

    executor_kwargs = {"server_config": test_environment.config,
                       "timeout_multiplier": timeout_multiplier,
                       "debug_info": kwargs["debug_info"]}

    if test_type in ("reftest", "print-reftest"):
        executor_kwargs["screenshot_cache"] = test_environment.cache_manager.dict()

    if test_type == "wdspec":
        executor_kwargs["binary"] = kwargs.get("binary")
        executor_kwargs["webdriver_binary"] = kwargs.get("webdriver_binary")
        executor_kwargs["webdriver_args"] = kwargs.get("webdriver_args")

    # By default the executor may try to cleanup windows after a test (to best
    # associate any problems with the test causing them). If the user might
    # want to view the results, however, the executor has to skip that cleanup.
    if kwargs["pause_after_test"] or kwargs["pause_on_unexpected"]:
        executor_kwargs["cleanup_after_test"] = False
    executor_kwargs["debug_test"] = kwargs["debug_test"]
    return executor_kwargs


def strip_server(url):
    """Remove the scheme and netloc from a url, leaving only the path and any query
    or fragment.

    url - the url to strip

    e.g. http://example.org:8000/tests?id=1#2 becomes /tests?id=1#2"""

    url_parts = list(urlsplit(url))
    url_parts[0] = ""
    url_parts[1] = ""
    return urlunsplit(url_parts)


class TestharnessResultConverter(object):
    """Callable converting a raw testharness result tuple into result objects."""

    # Numeric harness status codes -> status names.
    harness_codes = {0: "OK",
                     1: "ERROR",
                     2: "TIMEOUT",
                     3: "PRECONDITION_FAILED"}

    # Numeric subtest status codes -> status names.
    test_codes = {0: "PASS",
                  1: "FAIL",
                  2: "TIMEOUT",
                  3: "NOTRUN",
                  4: "PRECONDITION_FAILED"}

    def __call__(self, test, result, extra=None):
        """Convert a JSON result into a (TestResult, [SubtestResult]) tuple"""
        result_url, status, message, stack, subtest_results = result
        assert result_url == test.url, ("Got results from %s, expected %s" %
                                        (result_url, test.url))
        harness_result = test.result_cls(self.harness_codes[status], message, extra=extra, stack=stack)
        return (harness_result,
                [test.subtest_result_cls(st_name, self.test_codes[st_status], st_message, st_stack)
                 for st_name, st_status, st_message, st_stack in subtest_results])


testharness_result_converter = TestharnessResultConverter()


def hash_screenshots(screenshots):
    """Computes the sha1 checksum of a list of base64-encoded screenshots."""
    return [hashlib.sha1(base64.b64decode(screenshot)).hexdigest()
            for screenshot in screenshots]


def _ensure_hash_in_reftest_screenshots(extra):
    """Make sure reftest_screenshots have hashes.

    Marionette internal reftest runner does not produce hashes.
    """
    log_data = extra.get("reftest_screenshots")
    if not log_data:
        return
    for item in log_data:
        if not isinstance(item, dict):
            # Skip relation strings.
            continue
        if "hash" not in item:
            item["hash"] = hash_screenshots([item["screenshot"]])[0]


def get_pages(ranges_value, total_pages):
    """Get a set of page numbers to include in a print reftest.

    :param ranges_value: Parsed page ranges as a list e.g. [[1,2], [4], [6,None]]
    :param total_pages: Integer total number of pages in the paginated output.
    :retval: Set containing integer page numbers to include in the comparison e.g.
             for the example ranges value and 10 total pages this would be
             {1,2,4,6,7,8,9,10}. Only pages in 1..total_pages are included.
             ``ranges_value`` is not modified."""
    if not ranges_value:
        return set(range(1, total_pages + 1))

    rv = set()

    for range_limits in ranges_value:
        # Use locals rather than rewriting range_limits in place: the caller's
        # parsed ranges may be reused with a different total_pages, and
        # baking one call's total_pages into them would corrupt later calls.
        if len(range_limits) == 1:
            start = end = range_limits[0]
        else:
            start, end = range_limits[0], range_limits[1]

        if start is None:
            start = 1
        if end is None:
            end = total_pages

        if start > total_pages:
            continue
        # Clamp to the pages that actually exist.
        rv |= set(range(max(start, 1), min(end, total_pages) + 1))
    return rv


def reftest_result_converter(self, test, result):
    """Convert a reftest result dict into a (Result, []) tuple.

    Takes ``self`` because it is installed as the ``convert_result`` class
    attribute of reftest executors. Missing screenshot hashes in
    ``result["extra"]`` are filled in.
    """
    extra = result.get("extra", {})
    _ensure_hash_in_reftest_screenshots(extra)
    return (test.result_cls(
        result["status"],
        result["message"],
        extra=extra,
        stack=result.get("stack")), [])


def pytest_result_converter(self, test, data):
    """Convert (harness_data, subtest_data) from the pytest runner into
    (Result, [SubtestResult]); ``subtest_data`` may be None."""
    harness_data, subtest_data = data

    if subtest_data is None:
        subtest_data = []

    harness_result = test.result_cls(*harness_data)
    subtest_results = [test.subtest_result_cls(*item) for item in subtest_data]

    return (harness_result, subtest_results)


def crashtest_result_converter(self, test, result):
    """Convert a crashtest result dict (result_cls keyword args) into (Result, [])."""
    return test.result_cls(**result), []


class ExecutorException(Exception):
    """Exception carrying an explicit result ``status`` and ``message``."""

    def __init__(self, status, message):
        self.status = status
        self.message = message


class TimedRunner(object):
    """Run ``run_func`` in a worker thread and enforce an external timeout.

    Subclasses implement ``set_timeout`` and ``run_func`` (and optionally
    ``before_run``). ``run_func`` is expected to set ``self.result`` and then
    ``self.result_flag``.
    """

    def __init__(self, logger, func, protocol, url, timeout, extra_timeout):
        self.func = func
        self.logger = logger
        self.result = None
        self.protocol = protocol
        self.url = url
        self.timeout = timeout
        self.extra_timeout = extra_timeout
        self.result_flag = threading.Event()

    def run(self):
        """Run the test and return a (success, data) tuple.

        On failure ``data`` is a (status, message) pair such as
        ("EXTERNAL-TIMEOUT", ...), ("CRASH", None) or ("INTERNAL-ERROR", ...).
        """
        for setup_fn in [self.set_timeout, self.before_run]:
            err = setup_fn()
            if err:
                self.result = (False, err)
                return self.result

        executor = threading.Thread(target=self.run_func)
        executor.start()

        # Add twice the extra timeout since the called function is expected to
        # wait at least self.timeout + self.extra_timeout and this gives some leeway
        timeout = self.timeout + 2 * self.extra_timeout if self.timeout else None
        finished = self.result_flag.wait(timeout)
        if self.result is None:
            if finished:
                # flag is True unless we timeout; this *shouldn't* happen, but
                # it can if self.run_func fails to set self.result due to raising
                self.result = False, ("INTERNAL-ERROR", "%s.run_func didn't set a result" %
                                      self.__class__.__name__)
            else:
                if self.protocol.is_alive():
                    message = "Executor hit external timeout (this may indicate a hang)\n"
                    # get a traceback for the current stack of the executor thread;
                    # the thread may have exited since the wait timed out, in
                    # which case it has no frame any more.
                    frame = sys._current_frames().get(executor.ident)
                    if frame is not None:
                        message += "".join(traceback.format_stack(frame))
                    self.result = False, ("EXTERNAL-TIMEOUT", message)
                else:
                    self.logger.info("Browser not responding, setting status to CRASH")
                    self.result = False, ("CRASH", None)
        elif self.result[1] is None:
            # We didn't get any data back from the test, so check if the
            # browser is still responsive
            if self.protocol.is_alive():
                self.result = False, ("INTERNAL-ERROR", None)
            else:
                self.logger.info("Browser not responding, setting status to CRASH")
                self.result = False, ("CRASH", None)

        return self.result

    def set_timeout(self):
        raise NotImplementedError

    def before_run(self):
        pass

    def run_func(self):
        raise NotImplementedError


class TestExecutor(object):
    """Abstract Base class for object that actually executes the tests in a
    specific browser. Typically there will be a different TestExecutor
    subclass for each test type and method of executing tests.

    :param browser: ExecutorBrowser instance providing properties of the
                    browser that will be tested.
    :param server_config: Dictionary of wptserve server configuration of the
                          form stored in TestEnvironment.config
    :param timeout_multiplier: Multiplier relative to base timeout to use
                               when setting test timeout.
    """
    # NOTE(review): ``__metaclass__`` is Python 2 syntax and has no effect on
    # Python 3, so ``@abstractmethod`` below is not enforced. Left unchanged
    # because enforcing it could break subclasses defined elsewhere.
    __metaclass__ = ABCMeta

    test_type = None  # type: ClassVar[str]
    # convert_result is a class variable set to a callable converter
    # (e.g. reftest_result_converter) converting from an instance of
    # URLManifestItem (e.g. RefTest) + type-dependent results object +
    # type-dependent extra data, returning a tuple of Result and list of
    # SubtestResult. For now, any callable is accepted. TODO: Make this type
    # stricter when more of the surrounding code is annotated.
    convert_result = None  # type: ClassVar[Callable[..., Any]]
    supports_testdriver = False
    supports_jsshell = False
    # Extra timeout to use after internal test timeout at which the harness
    # should force a timeout
    extra_timeout = 5  # seconds

    def __init__(self, logger, browser, server_config, timeout_multiplier=1,
                 debug_info=None, **kwargs):
        self.logger = logger
        self.runner = None
        self.browser = browser
        self.server_config = server_config
        self.timeout_multiplier = timeout_multiplier
        self.debug_info = debug_info
        self.last_environment = {"protocol": "http",
                                 "prefs": {}}
        self.protocol = None  # This must be set in subclasses

    def setup(self, runner):
        """Run steps needed before tests can be started e.g. connecting to
        browser instance

        :param runner: TestRunner instance that is going to run the tests"""
        self.runner = runner
        if self.protocol is not None:
            self.protocol.setup(runner)

    def teardown(self):
        """Run cleanup steps after tests have finished"""
        if self.protocol is not None:
            self.protocol.teardown()

    def reset(self):
        """Re-initialize internal state to facilitate repeated test execution
        as implemented by the `--rerun` command-line argument."""
        pass

    def run_test(self, test):
        """Run a particular test.

        Any exception from ``do_test`` is converted into a result via
        ``result_from_exception``; the result is sent to the runner as a
        "test_ended" message.

        :param test: The test to run"""
        try:
            if test.environment != self.last_environment:
                self.on_environment_change(test.environment)
            result = self.do_test(test)
        except Exception as e:
            exception_string = traceback.format_exc()
            self.logger.warning(exception_string)
            result = self.result_from_exception(test, e, exception_string)

        # log result of parent test
        if result[0].status == "ERROR":
            self.logger.debug(result[0].message)

        self.last_environment = test.environment

        self.runner.send_message("test_ended", test, result)

    def server_url(self, protocol, subdomain=False):
        """Return "scheme://host:port" for the given server protocol."""
        scheme = "https" if protocol == "h2" else protocol
        host = self.server_config["browser_host"]
        if subdomain:
            # The only supported subdomain filename flag is "www".
            host = "{subdomain}.{host}".format(subdomain="www", host=host)
        return "{scheme}://{host}:{port}".format(scheme=scheme, host=host,
                                                 port=self.server_config["ports"][protocol][0])

    def test_url(self, test):
        """Return the absolute URL of ``test`` on the test server."""
        return urljoin(self.server_url(test.environment["protocol"],
                                       test.subdomain), test.url)

    @abstractmethod
    def do_test(self, test):
        """Test-type and protocol specific implementation of running a
        specific test.

        :param test: The test to run."""
        pass

    def on_environment_change(self, new_environment):
        pass

    def result_from_exception(self, test, e, exception_string):
        """Build a (Result, []) tuple from an exception raised by a test.

        Uses ``e.status`` when it is a valid status for the result class,
        otherwise "INTERNAL-ERROR"; the message is ``e.message`` (if any)
        followed by the formatted traceback.
        """
        if hasattr(e, "status") and e.status in test.result_cls.statuses:
            status = e.status
        else:
            status = "INTERNAL-ERROR"
        message = str(getattr(e, "message", ""))
        if message:
            message += "\n"
        message += exception_string
        return test.result_cls(status, message), []

    def wait(self):
        self.protocol.base.wait()


class TestharnessExecutor(TestExecutor):
    convert_result = testharness_result_converter


class RefTestExecutor(TestExecutor):
    convert_result = reftest_result_converter
    is_print = False

    def __init__(self, logger, browser, server_config, timeout_multiplier=1, screenshot_cache=None,
                 debug_info=None, **kwargs):
        TestExecutor.__init__(self, logger, browser, server_config,
                              timeout_multiplier=timeout_multiplier,
                              debug_info=debug_info)

        self.screenshot_cache = screenshot_cache


class CrashtestExecutor(TestExecutor):
    convert_result = crashtest_result_converter


class PrintRefTestExecutor(TestExecutor):
    convert_result = reftest_result_converter
    is_print = True


class RefTestImplementation(object):
    """Screenshot-comparison logic shared by reftest executors."""

    def __init__(self, executor):
        self.timeout_multiplier = executor.timeout_multiplier
        self.executor = executor
        # Cache of url:(screenshot hash, screenshot). Typically the
        # screenshot is None, but we set this value if a test fails
        # and the screenshot was taken from the cache so that we may
        # retrieve the screenshot from the cache directly in the future
        self.screenshot_cache = self.executor.screenshot_cache
        # Lines of diagnostic text accumulated during run_test.
        self.message = None

    def setup(self):
        pass

    def teardown(self):
        pass

    @property
    def logger(self):
        return self.executor.logger

    def get_hash(self, test, viewport_size, dpi, page_ranges):
        """Return (True, (hashes, screenshots)) for ``test``, or (False, error).

        Results are cached by (url, viewport_size, dpi).
        """
        key = (test.url, viewport_size, dpi)

        if key not in self.screenshot_cache:
            success, data = self.get_screenshot_list(test, viewport_size, dpi, page_ranges)

            if not success:
                return False, data

            screenshots = data
            hash_values = hash_screenshots(data)
            self.screenshot_cache[key] = (hash_values, screenshots)

            rv = (hash_values, screenshots)
        else:
            rv = self.screenshot_cache[key]

        self.message.append("%s %s" % (test.url, rv[0]))
        return True, rv

    def reset(self):
        self.screenshot_cache.clear()

    def check_pass(self, hashes, screenshots, urls, relation, fuzzy):
        """Check if a test passes, and return a tuple of (pass, page_idx),
        where page_idx is the zero-based index of the first page on which a
        difference occurs if any, or None if there are no differences"""

        assert relation in ("==", "!=")
        lhs_hashes, rhs_hashes = hashes
        lhs_screenshots, rhs_screenshots = screenshots

        if len(lhs_hashes) != len(rhs_hashes):
            self.logger.info("Got different number of pages")
            return relation == "!=", None

        assert len(lhs_screenshots) == len(lhs_hashes) == len(rhs_screenshots) == len(rhs_hashes)

        # ``hashes`` is always the (lhs, rhs) pair, so the page count must be
        # taken from one side rather than from len(hashes).
        multipage = len(lhs_hashes) > 1

        for (page_idx, (lhs_hash,
                        rhs_hash,
                        lhs_screenshot,
                        rhs_screenshot)) in enumerate(zip(lhs_hashes,
                                                          rhs_hashes,
                                                          lhs_screenshots,
                                                          rhs_screenshots)):
            comparison_screenshots = (lhs_screenshot, rhs_screenshot)
            # Page index reported in log messages; None for single-page output.
            log_page_idx = page_idx if multipage else None
            if not fuzzy or fuzzy == ((0, 0), (0, 0)):
                equal = lhs_hash == rhs_hash
                # sometimes images can have different hashes, but pixels can be identical.
                if not equal:
                    self.logger.info("Image hashes didn't match%s, checking pixel differences" %
                                     ("" if not multipage else " on page %i" % (page_idx + 1)))
                    max_per_channel, pixels_different = self.get_differences(comparison_screenshots,
                                                                             urls,
                                                                             log_page_idx)
                    equal = pixels_different == 0 and max_per_channel == 0
            else:
                max_per_channel, pixels_different = self.get_differences(comparison_screenshots,
                                                                         urls,
                                                                         log_page_idx)
                allowed_per_channel, allowed_different = fuzzy
                self.logger.info("Allowed %s pixels different, maximum difference per channel %s" %
                                 ("-".join(str(item) for item in allowed_different),
                                  "-".join(str(item) for item in allowed_per_channel)))
                equal = ((pixels_different == 0 and allowed_different[0] == 0) or
                         (max_per_channel == 0 and allowed_per_channel[0] == 0) or
                         (allowed_per_channel[0] <= max_per_channel <= allowed_per_channel[1] and
                          allowed_different[0] <= pixels_different <= allowed_different[1]))
            if not equal:
                return (False if relation == "==" else True, page_idx)
        # All screenshots were equal within the fuzziness
        return (True if relation == "==" else False, None)

    def get_differences(self, screenshots, urls, page_idx=None):
        """Compare two base64-encoded screenshots.

        Returns (max per-channel difference, count of pixels in the non-zero
        difference mask). Solid-colour screenshots are noted in self.message.
        """
        from PIL import Image, ImageChops, ImageStat

        lhs = Image.open(io.BytesIO(base64.b64decode(screenshots[0]))).convert("RGB")
        rhs = Image.open(io.BytesIO(base64.b64decode(screenshots[1]))).convert("RGB")
        self.check_if_solid_color(lhs, urls[0])
        self.check_if_solid_color(rhs, urls[1])
        diff = ImageChops.difference(lhs, rhs)
        minimal_diff = diff.crop(diff.getbbox())
        mask = minimal_diff.convert("L", dither=None)
        stat = ImageStat.Stat(minimal_diff, mask)
        per_channel = max(item[1] for item in stat.extrema)
        count = stat.count[0]
        self.logger.info("Found %s pixels different, maximum difference per channel %s%s" %
                         (count,
                          per_channel,
                          "" if page_idx is None else " on page %i" % (page_idx + 1)))
        return per_channel, count

    def check_if_solid_color(self, image, url):
        """Append a note to self.message if ``image`` is a single solid colour."""
        extrema = image.getextrema()
        if all(lo == hi for lo, hi in extrema):
            color = ''.join('%02X' % value for value, _ in extrema)
            self.message.append("Screenshot is solid color 0x%s for %s\n" % (color, url))

    def run_test(self, test):
        """Run a reftest and return a result dict with "status" and "message".

        On failure the dict also carries "extra" with the screenshots of the
        last compared pair.
        """
        viewport_size = test.viewport_size
        dpi = test.dpi
        page_ranges = test.page_ranges
        self.message = []

        # Depth-first search of reference tree, with the goal
        # of reachings a leaf node with only pass results

        stack = list(((test, item[0]), item[1]) for item in reversed(test.references))
        page_idx = None
        while stack:
            hashes = [None, None]
            screenshots = [None, None]
            urls = [None, None]

            nodes, relation = stack.pop()
            fuzzy = self.get_fuzzy(test, nodes, relation)

            for i, node in enumerate(nodes):
                success, data = self.get_hash(node, viewport_size, dpi, page_ranges)
                if success is False:
                    return {"status": data[0], "message": data[1]}

                hashes[i], screenshots[i] = data
                urls[i] = node.url

            is_pass, page_idx = self.check_pass(hashes, screenshots, urls, relation, fuzzy)
            if is_pass:
                if nodes[1].references:
                    # The reference itself has references; all of them must
                    # also be satisfied along this path.
                    stack.extend(list(((nodes[1], item[0]), item[1])
                                      for item in reversed(nodes[1].references)))
                else:
                    # We passed
                    return {"status": "PASS", "message": None}

        # We failed, so construct a failure message

        if page_idx is None:
            # default to outputting the last page
            page_idx = -1
        for i, (node, screenshot) in enumerate(zip(nodes, screenshots)):
            if screenshot is None:
                success, screenshot = self.retake_screenshot(node, viewport_size, dpi, page_ranges)
                if success:
                    screenshots[i] = screenshot

        log_data = [
            {"url": nodes[0].url,
             "screenshot": screenshots[0][page_idx],
             "hash": hashes[0][page_idx]},
            relation,
            {"url": nodes[1].url,
             "screenshot": screenshots[1][page_idx],
             "hash": hashes[1][page_idx]},
        ]

        return {"status": "FAIL",
                "message": "\n".join(self.message),
                "extra": {"reftest_screenshots": log_data}}

    def get_fuzzy(self, root_test, test_nodes, relation):
        """Return the fuzzy-match allowance for a (test, ref) pair, or None.

        ``root_test.fuzzy_override`` is consulted before the node's own
        ``fuzzy``; within each, the (urls..., relation) key is tried first,
        then the reference url, then the None (default) key.
        """
        full_key = tuple([item.url for item in test_nodes] + [relation])
        ref_only_key = test_nodes[1].url

        fuzzy_override = root_test.fuzzy_override
        fuzzy = test_nodes[0].fuzzy

        sources = [fuzzy_override, fuzzy]
        keys = [full_key, ref_only_key, None]
        value = None
        for source in sources:
            for key in keys:
                if key in source:
                    value = source[key]
                    break
            if value:
                break
        return value

    def retake_screenshot(self, node, viewport_size, dpi, page_ranges):
        """Take fresh screenshots of ``node`` and store them in the cache
        alongside the previously cached hashes."""
        success, data = self.get_screenshot_list(node,
                                                 viewport_size,
                                                 dpi,
                                                 page_ranges)
        if not success:
            return False, data

        key = (node.url, viewport_size, dpi)
        hash_val, _ = self.screenshot_cache[key]
        self.screenshot_cache[key] = hash_val, data
        return True, data

    def get_screenshot_list(self, node, viewport_size, dpi, page_ranges):
        """Take screenshots via the executor, always returning a list on success."""
        success, data = self.executor.screenshot(node, viewport_size, dpi, page_ranges)
        if success and not isinstance(data, list):
            return success, [data]
        return success, data


class WdspecExecutor(TestExecutor):
    convert_result = pytest_result_converter
    protocol_cls = None  # type: ClassVar[Type[Protocol]]

    def __init__(self, logger, browser, server_config, webdriver_binary,
                 webdriver_args, timeout_multiplier=1, capabilities=None,
                 debug_info=None, environ=None, **kwargs):
        self.do_delayed_imports()
        TestExecutor.__init__(self, logger, browser, server_config,
                              timeout_multiplier=timeout_multiplier,
                              debug_info=debug_info)
        self.webdriver_binary = webdriver_binary
        self.webdriver_args = webdriver_args
        self.timeout_multiplier = timeout_multiplier
        self.capabilities = capabilities
        self.environ = environ if environ is not None else {}
        self.output_handler_kwargs = None
        self.output_handler_start_kwargs = None

    def setup(self, runner):
        self.protocol = self.protocol_cls(self, self.browser)
        super().setup(runner)

    def is_alive(self):
        return self.protocol.is_alive()

    def on_environment_change(self, new_environment):
        pass

    def do_test(self, test):
        """Run a wdspec test through pytest, enforcing an external timeout."""
        timeout = test.timeout * self.timeout_multiplier + self.extra_timeout

        success, data = WdspecRun(self.do_wdspec,
                                  self.protocol.session_config,
                                  test.abs_path,
                                  timeout).run()

        if success:
            return self.convert_result(test, data)

        return (test.result_cls(*data), [])

    def do_wdspec(self, session_config, path, timeout):
        return pytestrunner.run(path,
                                self.server_config,
                                session_config,
                                timeout=timeout)

    def do_delayed_imports(self):
        global pytestrunner
        from . import pytestrunner


class WdspecRun(object):
    """Run a wdspec function in a worker thread with a timeout."""

    def __init__(self, func, session, path, timeout):
        self.func = func
        self.result = (None, None)
        self.session = session
        self.path = path
        self.timeout = timeout
        self.result_flag = threading.Event()

    def run(self):
        """Runs function in a thread and interrupts it if it exceeds the
        given timeout.  Returns (True, (Result, [SubtestResult ...])) in
        case of success, or (False, (status, extra information)) in the
        event of failure.
        """

        executor = threading.Thread(target=self._run)
        executor.start()

        self.result_flag.wait(self.timeout)
        if self.result[1] is None:
            self.result = False, ("EXTERNAL-TIMEOUT", None)

        return self.result

    def _run(self):
        try:
            self.result = True, self.func(self.session, self.path, self.timeout)
        except (socket.timeout, IOError):
            self.result = False, ("CRASH", None)
        except Exception as e:
            # Most Python 3 exceptions have no ``message`` attribute; without
            # a default the handler itself would raise, leave self.result
            # unset, and the failure would be misreported as a timeout.
            message = str(getattr(e, "message", ""))
            if message:
                message += "\n"
            message += traceback.format_exc()
            self.result = False, ("INTERNAL-ERROR", message)
        finally:
            self.result_flag.set()


class ConnectionlessBaseProtocolPart(BaseProtocolPart):
    """No-op base protocol part for executors that do not drive a browser."""

    def load(self, url):
        pass

    def execute_script(self, script, asynchronous=False):
        pass

    def set_timeout(self, timeout):
        pass

    def wait(self):
        pass

    def set_window(self, handle):
        pass

    def window_handles(self):
        return []


class ConnectionlessProtocol(Protocol):
    implements = [ConnectionlessBaseProtocolPart]

    def connect(self):
        pass

    def after_connect(self):
        pass


class WdspecProtocol(Protocol):
    server_cls = None  # type: ClassVar[Optional[Type[WebDriverServer]]]

    implements = [ConnectionlessBaseProtocolPart]

    def __init__(self, executor, browser):
        Protocol.__init__(self, executor, browser)
        self.webdriver_binary = executor.webdriver_binary
        self.webdriver_args = executor.webdriver_args
        self.capabilities = self.executor.capabilities
        self.session_config = None
        self.server = None
        self.environ = os.environ.copy()
        self.environ.update(executor.environ)
        self.output_handler_kwargs = executor.output_handler_kwargs
        self.output_handler_start_kwargs = executor.output_handler_start_kwargs

    def connect(self):
        """Connect to browser via the HTTP server."""
        self.server = self.server_cls(
            self.logger,
            binary=self.webdriver_binary,
            args=self.webdriver_args,
            env=self.environ)
        self.server.start(block=False,
                          output_handler_kwargs=self.output_handler_kwargs,
                          output_handler_start_kwargs=self.output_handler_start_kwargs)
        self.logger.info(
            "WebDriver HTTP server listening at %s" % self.server.url)
        self.session_config = {"host": self.server.host,
                               "port": self.server.port,
                               "capabilities": self.capabilities}

    def after_connect(self):
        pass

    def teardown(self):
        if self.server is not None and self.server.is_alive():
            self.server.stop()

    def is_alive(self):
        """Test that the connection is still alive.

        Because the remote communication happens over HTTP we need to
        make an explicit request to the remote.  It is allowed for
        WebDriver spec tests to not have a WebDriver session, since this
        may be what is tested.

        An HTTP request to an invalid path that results in a 404 is
        proof enough to us that the server is alive and kicking.

        Returns False if no server was started or the server cannot be
        reached; the connection is always closed.
        """
        if self.server is None:
            return False
        conn = HTTPConnection(self.server.host, self.server.port)
        try:
            conn.request("HEAD", self.server.base_path + "invalid")
            res = conn.getresponse()
            return res.status == 404
        except (OSError, HTTPException):
            # Connection refused/reset or a malformed response: not alive.
            return False
        finally:
            conn.close()


class CallbackHandler(object):
    """Handle callbacks from testdriver-using tests.

    The default implementation here makes sense for things that are roughly like
    WebDriver. Things that are more different to WebDriver may need to create a
    fully custom implementation."""

    unimplemented_exc = (NotImplementedError,)  # type: ClassVar[Tuple[Type[Exception], ...]]

    def __init__(self, logger, protocol, test_window):
        self.protocol = protocol
        self.test_window = test_window
        self.logger = logger
        self.callbacks = {
            "action": self.process_action,
            "complete": self.process_complete
        }

        self.actions = {cls.name: cls(self.logger, self.protocol) for cls in actions}

    def __call__(self, result):
        """Dispatch a (url, command, payload) callback to its handler.

        Returns the handler's (done, result) tuple; raises ValueError for an
        unknown command.
        """
        url, command, payload = result
        self.logger.debug("Got async callback: %s" % command)
        try:
            callback = self.callbacks[command]
        except KeyError:
            raise ValueError("Unknown callback type %r" % command) from None
        return callback(url, payload)

    def process_complete(self, url, payload):
        rv = [strip_server(url)] + payload
        return True, rv

    def process_action(self, url, payload):
        """Run a testdriver action and send its completion message back.

        Unimplemented actions are reported as errors to the test; any other
        exception is reported and then re-raised.
        """
        action = payload["action"]
        cmd_id = payload["id"]
        self.logger.debug("Got action: %s" % action)
        try:
            action_handler = self.actions[action]
        except KeyError:
            raise ValueError("Unknown action %s" % action) from None
        try:
            with ActionContext(self.logger, self.protocol, payload.get("context")):
                result = action_handler(payload)
        except self.unimplemented_exc:
            self.logger.warning("Action %s not implemented" % action)
            self._send_message(cmd_id, "complete", "error", "Action %s not implemented" % action)
        except Exception:
            self.logger.warning("Action %s failed" % action)
            self.logger.warning(traceback.format_exc())
            self._send_message(cmd_id, "complete", "error")
            raise
        else:
            self.logger.debug("Action %s completed with result %s" % (action, result))
            return_message = {"result": result}
            self._send_message(cmd_id, "complete", "success", json.dumps(return_message))

        return False, None

    def _send_message(self, cmd_id, message_type, status, message=None):
        self.protocol.testdriver.send_message(cmd_id, message_type, status, message=message)


class ActionContext(object):
    """Context manager that switches to the action's window (if one is given)
    and restores the previously current window on exit."""

    def __init__(self, logger, protocol, context):
        self.logger = logger
        self.protocol = protocol
        self.context = context
        self.initial_window = None

    def __enter__(self):
        if self.context is None:
            return

        self.initial_window = self.protocol.base.current_window
        self.logger.debug("Switching to window %s" % self.context)
        self.protocol.testdriver.switch_to_window(self.context, self.initial_window)

    def __exit__(self, *args):
        if self.context is None:
            return

        self.logger.debug("Switching back to initial window")
        self.protocol.base.set_window(self.initial_window)
        self.initial_window = None