# Copyright (c) 2012 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.


import logging
import multiprocessing

from test_result import TestResults


def _ShardedTestRunnable(test):
  """Standalone function needed by multiprocessing.Pool.

  Runs a single shard inside a pool worker process.

  Args:
    test: A test runner as created by
        BaseTestSharder.CreateShardedTestRunner. It must expose a 'device'
        attribute (concatenated into the log prefix, so presumably the device
        serial string) and a 'Run()' method.

  Returns:
    The value returned by test.Run(), or an empty TestResults if the run
    raised SystemExit.
  """
  # Prefix every log line from this worker with its device, because the output
  # of all shards is interleaved on the same console.
  log_format = '[' + test.device + '] # %(asctime)-15s: %(message)s'
  if logging.getLogger().handlers:
    logging.getLogger().handlers[0].setFormatter(logging.Formatter(log_format))
  else:
    logging.basicConfig(format=log_format)
  # Handle SystemExit here since python has a bug to exit current process:
  # a sys.exit() raised during the run is turned into an empty result instead
  # of propagating out of the worker.
  try:
    return test.Run()
  except SystemExit:
    return TestResults()


def SetTestsContainer(tests_container):
  """Sets tests container.

  multiprocessing.Queue can't be pickled across processes, so we need to set
  this as a 'global', per process, via multiprocessing.Pool. RunShardedTests
  passes this function as the Pool initializer, so it runs once in each worker
  process.

  Args:
    tests_container: Value stored on BaseTestSharder.tests_container.
  """
  BaseTestSharder.tests_container = tests_container


class BaseTestSharder(object):
  """Base class for sharding tests across multiple devices.

  RunShardedTests creates one test runner per attached device via
  CreateShardedTestRunner and runs them in parallel in a multiprocessing pool.
  Tests reported as broken are re-run on the next try, for at most
  self.retries tries in total.

  Args:
    attached_devices: A list of attached devices.
  """
  # Handed to every pool worker through SetTestsContainer (see that function
  # for why it cannot travel with the pickled test runners).
  tests_container = None

  def __init__(self, attached_devices):
    self.attached_devices = attached_devices
    # Total number of tries RunShardedTests may perform (1 means no retry).
    self.retries = 1
    # Tests handed to SetupSharding at the start of each try; after a non-final
    # try it is replaced with the names of the broken tests.
    self.tests = []

  def CreateShardedTestRunner(self, device, index):
    """Factory function to create a suite-specific test runner.

    The base implementation does nothing and returns None.

    Args:
      device: Device serial where this shard will run
      index: Index of this device in the pool.

    Returns:
      An object of BaseTestRunner type (that can provide a "Run()" method).
    """
    pass

  def SetupSharding(self, tests):
    """Called before starting the shards.

    Invoked once per try with the tests to run on that try. The base
    implementation does nothing.
    """
    pass

  def OnTestsCompleted(self, test_runners, test_results):
    """Notifies that we completed the tests.

    Receives the test runners of the last try (empty if no try ran) and the
    final TestResults. The base implementation does nothing.
    """
    pass

  def _RunTestRunnersInPool(self, test_runners):
    """Runs each test runner in a pool worker and merges their results.

    The pool is always shut down before returning or raising, so repeated
    tries do not leave worker processes behind.

    Args:
      test_runners: Runners created by CreateShardedTestRunner, one per device.

    Returns:
      A TestResults object built from every runner's results.
    """
    pool = multiprocessing.Pool(len(self.attached_devices),
                                SetTestsContainer,
                                [BaseTestSharder.tests_container])
    try:
      # map can't handle KeyboardInterrupt exception. It's a python bug.
      # So use map_async instead; presumably the (very large) timeout on get()
      # is what keeps the wait interruptible.
      async_results = pool.map_async(_ShardedTestRunnable, test_runners)
      results_lists = async_results.get(999999)
    except BaseException:
      # Interrupted or failed: kill the workers rather than waiting for them.
      pool.terminate()
      raise
    else:
      # All tasks are done; let the idle workers exit normally.
      pool.close()
    finally:
      pool.join()
    return TestResults.FromTestResults(results_lists)

  def RunShardedTests(self):
    """Runs the tests in all connected devices.

    Returns:
      A TestResults object. When the final try runs, this is that try's
      results with 'ok' replaced by the passing tests accumulated over all
      tries. When the run stops early because nothing is broken, it holds
      only the accumulated passing tests.
    """
    logging.warning('*' * 80)
    logging.warning('Sharding in %d devices.', len(self.attached_devices))
    logging.warning('Note that the output is not synchronized.')
    logging.warning('Look for the "Final result" banner in the end.')
    logging.warning('*' * 80)
    final_results = TestResults()
    # Defined before the loop so OnTestsCompleted still receives a list when
    # self.retries is 0 and no try runs.
    test_runners = []
    for retry in xrange(self.retries):
      logging.warning('Try %d of %d', retry + 1, self.retries)
      self.SetupSharding(self.tests)
      test_runners = []
      for index, device in enumerate(self.attached_devices):
        logging.warning('*' * 80)
        logging.warning('Creating shard %d for %s', index, device)
        logging.warning('*' * 80)
        test_runners.append(self.CreateShardedTestRunner(device, index))
      logging.warning('Starting...')
      test_results = self._RunTestRunnersInPool(test_runners)
      if retry == self.retries - 1:
        # Last try: keep its full results, but also credit tests that passed
        # on any earlier try.
        all_passed = final_results.ok + test_results.ok
        final_results = test_results
        final_results.ok = all_passed
        break
      # Not the last try: bank the passes and re-run only the broken tests.
      final_results.ok += test_results.ok
      self.tests = [t.name for t in test_results.GetAllBroken()]
      if not self.tests:
        break
    self.OnTestsCompleted(test_runners, final_results)
    return final_results