1from __future__ import absolute_import
2import os
3import subprocess
4import sys
5
6import lit.Test
7import lit.TestRunner
8import lit.util
9from lit.formats.base import TestFormat
10
# True on Windows hosts, whether native ('win32') or Cygwin ('cygwin').
kIsWindows = sys.platform in ('win32', 'cygwin')
12
class GoogleBenchmark(TestFormat):
    """lit test format that exposes each Google Benchmark as one lit test.

    Executables whose file names end in one of ``self.test_suffixes`` and live
    in one of ``self.test_sub_dirs`` are queried with
    ``--benchmark_list_tests``. Every reported benchmark name becomes a lit
    test. That test is executed by re-running the executable with a
    ``--benchmark_filter`` that selects exactly that benchmark.
    """

    def __init__(self, test_sub_dirs, test_suffix, benchmark_args=()):
        """Configure where benchmarks are found and how they are invoked.

        Args:
          test_sub_dirs: ';'-separated list of sub-directories, relative to
            the suite's source path, that are searched for executables.
          test_suffix: File-name suffix identifying benchmark executables.
          benchmark_args: Extra command-line arguments appended to every
            benchmark run. The sequence is copied, so the caller's object is
            never aliased or mutated.
        """
        # Copy into a fresh list; the immutable default avoids the
        # shared-mutable-default pitfall.
        self.benchmark_args = list(benchmark_args)
        self.test_sub_dirs = os.path.normcase(str(test_sub_dirs)).split(';')

        # Convert once so both suffixes below are built from the same string.
        test_suffix = str(test_suffix)

        # On Windows, assume tests will also end in '.exe'.
        exe_suffix = test_suffix
        if kIsWindows:
            exe_suffix += '.exe'

        # Also check for .py files for testing purposes.
        self.test_suffixes = {exe_suffix, test_suffix + '.py'}

    def getBenchmarkTests(self, path, litConfig, localConfig):
        """getBenchmarkTests(path) - [name]

        Yield the names of the benchmarks available in a benchmark executable.

        Args:
          path: String path to a benchmark executable.
          litConfig: LitConfig instance, used to report discovery warnings.
          localConfig: TestingConfig instance whose ``environment`` is used
            when running the executable.

        Yields:
          One string per benchmark reported by ``--benchmark_list_tests``.
          Entries where any nesting component starts with 'DISABLED_' are
          skipped. If the executable exits with a non-zero status, a warning
          is emitted and nothing is yielded.
        """

        # TODO: allow splitting tests according to the "benchmark family" so
        # the output for a single family of tests all belongs to the same test
        # target.
        list_test_cmd = [path, '--benchmark_list_tests']
        try:
            output = subprocess.check_output(list_test_cmd,
                                             env=localConfig.environment)
        except subprocess.CalledProcessError as exc:
            litConfig.warning(
                "unable to discover google-benchmarks in %r: %s. Process output: %s"
                % (path, exc, lit.util.to_string(exc.output)))
            # End the generator with a plain return. Raising StopIteration
            # inside a generator is turned into RuntimeError on Python 3.7+
            # (PEP 479), which would abort test discovery entirely.
            return

        nested_tests = []
        for ln in output.splitlines(False):  # Don't keep newlines.
            ln = lit.util.to_string(ln)
            if not ln.strip():
                continue

            # Each leading pair of spaces is one nesting level. Drop any
            # enclosing prefixes deeper than the current line's level.
            index = 0
            while ln[index*2:index*2+2] == '  ':
                index += 1
            while len(nested_tests) > index:
                nested_tests.pop()

            ln = ln[index*2:]
            if ln.endswith('.'):
                # A trailing '.' marks a prefix for the lines nested below it.
                nested_tests.append(ln)
            elif any(name.startswith('DISABLED_')
                     for name in nested_tests + [ln]):
                # Gtest will internally skip these tests. No need to launch a
                # child process for it.
                continue
            else:
                yield ''.join(nested_tests) + ln

    def getTestsInDirectory(self, testSuite, path_in_suite,
                            litConfig, localConfig):
        """Yield one lit.Test.Test per benchmark found under path_in_suite.

        Only the configured sub-directories that actually exist are searched.
        Each test path is ``path_in_suite + (subdir, executable, benchmark)``.
        """
        source_path = testSuite.getSourcePath(path_in_suite)
        for subdir in self.test_sub_dirs:
            dir_path = os.path.join(source_path, subdir)
            if not os.path.isdir(dir_path):
                continue
            for fn in lit.util.listdir_files(dir_path,
                                             suffixes=self.test_suffixes):
                # Discover the tests in this executable.
                execpath = os.path.join(dir_path, fn)
                testnames = self.getBenchmarkTests(execpath, litConfig,
                                                   localConfig)
                for testname in testnames:
                    testPath = path_in_suite + (subdir, fn, testname)
                    yield lit.Test.Test(testSuite, testPath, localConfig,
                                        file_path=execpath)

    def execute(self, test, litConfig):
        """Run a single benchmark and translate the outcome into a lit result.

        Returns:
          A ``(result_code, output)`` pair with one of these codes:
            - PASS: the executable exits with 0 and its stdout contains the
              benchmark name, or ``litConfig.noExecute`` is set.
            - FAIL: the executable exits with a non-zero code.
            - TIMEOUT: ``litConfig.maxIndividualTestTime`` is exceeded.
            - UNRESOLVED: the name is missing from the output, or no existing
              executable is found on the test's path.
        """
        testPath, testName = os.path.split(test.getSourcePath())
        while not os.path.exists(testPath):
            # Handle GTest parametrized and typed tests, whose name includes
            # some '/'s: move path components back into the test name until
            # the executable itself is reached.
            parentPath, namePrefix = os.path.split(testPath)
            if parentPath == testPath:
                # os.path.split can make no further progress (e.g. '' or a
                # missing drive root); bail out instead of looping forever.
                return (lit.Test.UNRESOLVED,
                        'Unable to locate benchmark executable for %r'
                        % test.getSourcePath())
            testPath = parentPath
            testName = namePrefix + '/' + testName

        # Anchor at both ends. --benchmark_filter is a regex searched within
        # each benchmark name, so without '^' a filter for 'foo' would also
        # select e.g. 'bar_foo'.
        cmd = ([testPath, '--benchmark_filter=^%s$' % testName]
               + self.benchmark_args)

        if litConfig.noExecute:
            return lit.Test.PASS, ''

        try:
            out, err, exitCode = lit.util.executeCommand(
                cmd, env=test.config.environment,
                timeout=litConfig.maxIndividualTestTime)
        except lit.util.ExecuteCommandTimeoutException:
            return (lit.Test.TIMEOUT,
                    'Reached timeout of {} seconds'.format(
                        litConfig.maxIndividualTestTime)
                   )

        if exitCode:
            return lit.Test.FAIL, ('exit code: %d\n' % exitCode) + out + err

        # A successful run is expected to mention the benchmark by name.
        passing_test_line = testName
        if passing_test_line not in out:
            msg = ('Unable to find %r in google benchmark output:\n\n%s%s' %
                   (passing_test_line, out, err))
            return lit.Test.UNRESOLVED, msg

        return lit.Test.PASS, err + out
122