1# Copyright (c) 2012 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5
6import logging
7import multiprocessing
8
9from test_result import TestResults
10
11
def _ShardedTestRunnable(test):
  """Runs one shard; kept at module level so multiprocessing.Pool can use it.

  Args:
    test: A runner object exposing a 'device' string and a Run() method.

  Returns:
    Whatever test.Run() returns, or an empty TestResults when the runner
    raised SystemExit.
  """
  # Prefix every log line from this worker with the device it is serving.
  device_log_format = ''.join(
      ['[', test.device, '] # %(asctime)-15s: %(message)s'])
  root_logger = logging.getLogger()
  if not root_logger.handlers:
    logging.basicConfig(format=device_log_format)
  else:
    # Only the first installed handler is re-formatted.
    root_logger.handlers[0].setFormatter(
        logging.Formatter(device_log_format))
  # Handle SystemExit here since python has a bug to exit current process
  try:
    shard_results = test.Run()
  except SystemExit:
    shard_results = TestResults()
  return shard_results
24
def SetTestsContainer(tests_container):
  """Installs the tests container for the current process.

  Passed to multiprocessing.Pool as the worker initializer: a
  multiprocessing.Queue can't be pickled across processes, so each process
  keeps it as a per-process 'global' stored on BaseTestSharder.

  Args:
    tests_container: Object to expose as BaseTestSharder.tests_container.
  """
  setattr(BaseTestSharder, 'tests_container', tests_container)
32
33
class BaseTestSharder(object):
  """Base class for sharding tests across multiple devices.

  Subclasses supply the suite-specific behaviour by overriding
  CreateShardedTestRunner (and optionally SetupSharding and
  OnTestsCompleted); the implementations here do nothing.

  Attributes:
    attached_devices: The device list given to the constructor. One shard and
      one worker process is created per entry.
    retries: Number of passes RunShardedTests makes. Defaults to 1.
    tests: Test names handed to SetupSharding at the start of each pass; after
      a non-final pass it is replaced by the names of the broken tests.
  """
  # See more in SetTestsContainer.
  tests_container = None

  def __init__(self, attached_devices):
    """Initializes the sharder.

    Args:
      attached_devices: A list of attached devices.
    """
    self.attached_devices = attached_devices
    self.retries = 1
    self.tests = []

  def CreateShardedTestRunner(self, device, index):
    """Factory function to create a suite-specific test runner.

    Args:
      device: Device serial where this shard will run
      index: Index of this device in the pool.

    Returns:
      An object of BaseTestRunner type (that can provide a "Run()" method).
    """
    pass

  def SetupSharding(self, tests):
    """Called before starting the shards.

    Args:
      tests: The current value of self.tests.
    """
    pass

  def OnTestsCompleted(self, test_runners, test_results):
    """Notifies that we completed the tests.

    Args:
      test_runners: The runners created for the last pass that ran (an empty
        list when no pass ran).
      test_results: The TestResults object RunShardedTests is about to return.
    """
    pass

  def _RunInPool(self, test_runners):
    """Runs every runner in its own worker process and collects the results.

    Args:
      test_runners: Runner objects, one per shard.

    Returns:
      A list with the return value of each runner, in runner order.
    """
    pool = multiprocessing.Pool(len(self.attached_devices),
                                SetTestsContainer,
                                [BaseTestSharder.tests_container])
    try:
      # map can't handle KeyboardInterrupt exception. It's a python bug.
      # So use map_async instead.
      async_results = pool.map_async(_ShardedTestRunnable, test_runners)
      return async_results.get(999999)
    finally:
      # Release the workers now instead of leaking one pool per pass until
      # interpreter exit. By this point the results are either in or we are
      # propagating an error, so no useful work is left to interrupt.
      pool.terminate()
      pool.join()

  def RunShardedTests(self):
    """Runs the tests in all connected devices.

    Makes up to self.retries passes. Each pass creates one runner per device,
    runs them in parallel and merges their results. After a non-final pass the
    broken tests are queued for the next pass; if none are broken the loop
    stops early.

    Returns:
      A TestResults object. When the final pass ran, it is that pass's merged
      results with 'ok' replaced by the 'ok' values accumulated over all
      passes. When the loop stopped early (or never ran), it only carries the
      accumulated 'ok' values.
    """
    logging.warning('*' * 80)
    logging.warning('Sharding in %d devices.', len(self.attached_devices))
    logging.warning('Note that the output is not synchronized.')
    logging.warning('Look for the "Final result" banner in the end.')
    logging.warning('*' * 80)
    final_results = TestResults()
    # Bound before the loop so OnTestsCompleted still receives a list when
    # self.retries is 0 and the loop body never executes.
    test_runners = []
    for retry in range(self.retries):
      logging.warning('Try %d of %d', retry + 1, self.retries)
      self.SetupSharding(self.tests)
      test_runners = []
      for index, device in enumerate(self.attached_devices):
        logging.warning('*' * 80)
        logging.warning('Creating shard %d for %s', index, device)
        logging.warning('*' * 80)
        test_runners.append(self.CreateShardedTestRunner(device, index))
      logging.warning('Starting...')
      results_lists = self._RunInPool(test_runners)
      test_results = TestResults.FromTestResults(results_lists)
      if retry == self.retries - 1:
        # Last pass: report its full results, but keep every 'ok' gathered on
        # the earlier passes as well.
        all_passed = final_results.ok + test_results.ok
        final_results = test_results
        final_results.ok = all_passed
        break
      final_results.ok += test_results.ok
      # Only the broken tests are re-run on the next pass.
      self.tests = [t.name for t in test_results.GetAllBroken()]
      if not self.tests:
        break
    self.OnTestsCompleted(test_runners, final_results)
    return final_results
114