1#!/usr/bin/env python
2
3# Copyright 2019 The Cirq Developers
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     https://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16"""
Runs Python doctests on all Python source files in the cirq directory, except
for the engine client code under cirq/google/engine/client/.
18
19See also:
20    https://docs.python.org/3/library/doctest.html
21
22Usage:
23    python run_doctest.py [-q]
24
25The -q argument suppresses all output except the final result line and any error
26messages.
27"""
28
29from typing import Any, Dict, Iterable, List, Tuple
30
31import sys
32import glob
33import importlib.util
34import doctest
35
36from dev_tools import shell_tools
37from dev_tools.output_capture import OutputCapture
38
# Type alias used in this file to annotate imported module objects.
# Bug workaround: https://github.com/python/mypy/issues/1498
# (presumably the reason `Any` is used here instead of `types.ModuleType`).
ModuleType = Any
41
42
class Doctest:
    """The docstring snippets of one source file, ready to be executed.

    Attributes:
        file_name: Path of the file the module was loaded from; used when
            reporting failures.
        mod: The imported module whose docstrings are tested.
        test_globals: Globals passed to `doctest.testmod` for the snippets.
    """

    def __init__(self, file_name: str, mod: ModuleType, test_globals: Dict[str, Any]):
        self.file_name, self.mod, self.test_globals = file_name, mod, test_globals

    def run(self) -> doctest.TestResults:
        """Runs the module's doctests and returns `(failed, attempted)` counts.

        The summary report and verbose output are both turned off. Only the
        output that `doctest` itself emits for failing examples is produced.
        """
        return doctest.testmod(
            m=self.mod,
            verbose=False,
            report=False,
            globs=self.test_globals,
        )
51
52
def run_tests(
    file_paths: Iterable[str],
    include_modules: bool = True,
    include_local: bool = True,
    quiet: bool = True,
) -> doctest.TestResults:
    """Loads and executes the docstring snippets of every given file.

    Args:
        file_paths: Paths of the files whose docstrings are tested.
        include_modules: If True, snippets may use `cirq` (and the other
            pre-imported modules) without importing them explicitly,
            e.g. `>>> cirq.LineQubit(0)`.
        include_local: If True, all globals defined in the file under test
            are available to that file's snippets.
        quiet: If True, progress output is suppressed. Error messages for
            failing files are printed either way.

    Returns:
        A `doctest.TestResults` tuple `(failed, attempted)` summed over all
        files.
    """
    # Switch to a non-interactive backend so that calls to `plt.show()` made
    # by snippets are ignored.
    import matplotlib.pyplot as plt

    plt.switch_backend('pdf')

    loaded = load_tests(
        file_paths, include_modules=include_modules, include_local=include_local, quiet=quiet
    )
    verbose = not quiet
    if verbose:
        print()
    totals, failures = exec_tests(loaded, quiet=quiet)
    if verbose:
        print()
    for message in failures:
        print(message)
    return totals
89
90
def load_tests(
    file_paths: Iterable[str],
    include_modules: bool = True,
    include_local: bool = True,
    quiet: bool = True,
) -> List[Doctest]:
    """Prepares tests for code snippets from docstrings found in each file.

    Every file is imported as a module via `import_file`, whatever the value
    of `include_local`, because the resulting `Doctest` runs `doctest` against
    that module object.

    Args:
        file_paths: The list of files to test.
        include_modules: If True, the snippets can use `cirq`, `np` (numpy),
            `sympy` and `pd` (pandas) without explicitly importing them.
            E.g. `>>> cirq.LineQubit(0)`
        include_local: If True, all globals defined in the file under test
            may be used by its snippets. When a module global has the same
            name as one of the `include_modules` names, the latter wins.
        quiet: If True, no progress output is printed.

    Returns: A list of `Doctest` objects, one per file, in input order.
    """

    def try_print(*args: Any, **kwargs: Any) -> None:
        # Progress output is only shown when not running quietly.
        if not quiet:
            print(*args, **kwargs)

    if include_modules:
        import cirq
        import numpy
        import sympy
        import pandas

        base_globals: Dict[str, Any] = {'cirq': cirq, 'np': numpy, 'sympy': sympy, 'pd': pandas}
    else:
        base_globals = {}

    def make_globals(mod: ModuleType) -> Dict[str, Any]:
        # Each test gets its own dict, so snippets cannot leak names into the
        # shared `base_globals`.
        if not include_local:
            return dict(base_globals)
        test_globals = dict(mod.__dict__)
        test_globals.update(base_globals)
        return test_globals

    def make_test(file_path: str) -> Doctest:
        try_print('.', end='', flush=True)
        mod = import_file(file_path)
        return Doctest(file_path, mod, make_globals(mod))

    try_print('Loading tests   ', end='')
    tests = [make_test(file_path) for file_path in file_paths]
    try_print()
    return tests
142
143
def exec_tests(
    tests: Iterable[Doctest], quiet: bool = True
) -> Tuple[doctest.TestResults, List[str]]:
    """Executes each `Doctest`, keeping the captured output of failing ones.

    Args:
        tests: The prepared tests to execute, in order.
        quiet: If True, no progress markers are printed. Otherwise a `.` is
            printed for each passing test and an `F` for each failing one.

    Returns:
        A pair `(results, error_messages)`. `results` is a
        `doctest.TestResults` whose failure and attempt counts are summed
        over all tests. `error_messages` holds one report per failing test:
        a red-highlighted header with the file name and counts, followed by
        the output captured while that test ran.
    """

    def progress(*args, **kwargs):
        if not quiet:
            print(*args, **kwargs)

    progress('Executing tests ', end='')

    total_failed = 0
    total_attempted = 0
    reports: List[str] = []
    for test in tests:
        capture = OutputCapture()
        # Doctest writes failure details to stdout; capture them per test.
        with capture:
            outcome = test.run()
        total_failed += outcome.failed
        total_attempted += outcome.attempted
        if not outcome.failed:
            progress('.', end='', flush=True)
            continue

        progress('F', end='', flush=True)
        passed = outcome.attempted - outcome.failed
        header = '{}\n{} failed, {} passed, {} total\n'.format(
            test.file_name, outcome.failed, passed, outcome.attempted
        )
        reports.append(shell_tools.highlight(header, shell_tools.RED) + capture.content())

    progress()

    return doctest.TestResults(failed=total_failed, attempted=total_attempted), reports
185
186
def import_file(file_path: str) -> ModuleType:
    """Finds and runs a python file as if it were imported with an `import`
    statement.

    The module is registered in `sys.modules` under a fixed temporary name
    while its code executes, and that entry is always removed afterwards,
    even if the module's code raises.

    Args:
        file_path: The file to import.

    Returns: The imported module.

    Raises:
        ImportError: If no module loader can be found for `file_path`
            (e.g. its extension is not a recognized python source suffix).
        Exception: Anything raised while executing the module's code is
            propagated unchanged.
    """
    mod_name = 'cirq_doctest_module'
    # Find and create the module.
    spec = importlib.util.spec_from_file_location(mod_name, file_path)
    if spec is None or spec.loader is None:
        # `spec_from_file_location` returns None for unrecognized suffixes;
        # fail with a clear message rather than an AttributeError below.
        raise ImportError(f'Cannot import {file_path!r}: no module loader found for it.')
    mod = importlib.util.module_from_spec(spec)
    # Run the code in the module (but not with __name__ == '__main__').
    # Registering it first lets code that looks itself up in `sys.modules`
    # during import work.
    sys.modules[mod_name] = mod
    try:
        spec.loader.exec_module(mod)
        # The module may have replaced its own `sys.modules` entry; return
        # whatever is registered, as the import system does.
        mod = sys.modules[mod_name]
    finally:
        # Do not leave a stale entry behind, even if execution failed.
        sys.modules.pop(mod_name, None)
    return mod
206
207
def main():
    """Runs the doctests of every python file under `cirq/` and exits.

    The code under `cirq/google/engine/client/` is skipped. Passing `-q` as
    the first command-line argument suppresses progress output. Exits with
    status 1 if any snippet failed, otherwise 0.
    """
    import os

    quiet = len(sys.argv) >= 2 and sys.argv[1] == '-q'

    # `glob` makes no ordering guarantee; sort so the test order (and hence
    # the progress and error output) is deterministic.
    file_names = sorted(glob.glob('cirq/**/*.py', recursive=True))
    # Remove the engine client code. Normalize separators first so the
    # '/'-based prefix also matches where `glob` returns OS-native paths.
    excluded_prefix = 'cirq/google/engine/client/'
    file_names = [
        f for f in file_names if not f.replace(os.sep, '/').startswith(excluded_prefix)
    ]
    failed, attempted = run_tests(
        file_names, include_modules=True, include_local=False, quiet=quiet
    )

    if failed != 0:
        print(
            shell_tools.highlight(
                f'Failed: {failed} failed, {attempted - failed} passed, {attempted} total',
                shell_tools.RED,
            )
        )
        sys.exit(1)
    print(shell_tools.highlight(f'Passed: {attempted}', shell_tools.GREEN))
    sys.exit(0)
231
232
# Run the doctests only when this file is executed as a script, not when it
# is imported.
if __name__ == '__main__':
    main()
235