#!/usr/bin/env python

# Copyright 2019 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Runs python doctest on all python source files in the cirq directory.

Files under `cirq/google/engine/client/` are skipped.

See also:
    https://docs.python.org/3/library/doctest.html

Usage:
    python run_doctest.py [-q]

The -q argument suppresses all output except the final result line and any error
messages.
"""

from typing import Any, Callable, Dict, Iterable, List, Tuple

import sys
import os
import glob
import importlib.util
import doctest

from dev_tools import shell_tools
from dev_tools.output_capture import OutputCapture

# Bug workaround: https://github.com/python/mypy/issues/1498
ModuleType = Any

# POSIX-style path prefix (relative to the repo root) of files excluded by `main`.
_ENGINE_CLIENT_PREFIX = 'cirq/google/engine/client/'


def _make_printer(quiet: bool) -> Callable[..., None]:
    """Returns `print`, or a no-op accepting the same arguments when `quiet` is True."""
    if quiet:
        return lambda *args, **kwargs: None
    return print


class Doctest:
    """The doctests of one imported module plus the globals its snippets run with."""

    def __init__(self, file_name: str, mod: ModuleType, test_globals: Dict[str, Any]):
        self.file_name = file_name
        self.mod = mod
        self.test_globals = test_globals

    def run(self) -> doctest.TestResults:
        """Runs the docstring examples of `self.mod` against `self.test_globals`.

        Returns: The doctest results as (# failed, # attempted).
        """
        return doctest.testmod(self.mod, globs=self.test_globals, report=False, verbose=False)


def run_tests(
    file_paths: Iterable[str],
    include_modules: bool = True,
    include_local: bool = True,
    quiet: bool = True,
) -> doctest.TestResults:
    """Runs code snippets from docstrings found in each file.

    Args:
        file_paths: The list of files to test.
        include_modules: If True, the snippets can use `cirq` without explicitly
            importing it. E.g. `>>> cirq.LineQubit(0)`
        include_local: If True, the file under test is imported as a python
            module (only if the file extension is .py) and all globals defined
            in the file may be used by the snippets.
        quiet: Determines if progress output is shown. Error messages of
            failing files are printed regardless.

    Returns: A tuple with the results: (# tests failed, # tests attempted)
    """

    # Ignore calls to `plt.show()`.
    import matplotlib.pyplot as plt

    plt.switch_backend('pdf')

    tests = load_tests(
        file_paths, include_modules=include_modules, include_local=include_local, quiet=quiet
    )
    if not quiet:
        print()
    results, error_messages = exec_tests(tests, quiet=quiet)
    if not quiet:
        print()
    for error in error_messages:
        print(error)
    return results


def load_tests(
    file_paths: Iterable[str],
    include_modules: bool = True,
    include_local: bool = True,
    quiet: bool = True,
) -> List[Doctest]:
    """Prepares tests for code snippets from docstrings found in each file.

    Every file is imported (its top-level code is executed) while loading.

    Args:
        file_paths: The list of files to test.
        include_modules: If True, the snippets can use `cirq` without explicitly
            importing it. E.g. `>>> cirq.LineQubit(0)`. `numpy` (as `np`),
            `sympy` and `pandas` (as `pd`) are made available the same way.
        include_local: If True, the file under test is imported as a python
            module (only if the file extension is .py) and all globals defined
            in the file may be used by the snippets.
        quiet: If True, no progress output is printed.

    Returns: A list of `Doctest` objects.
    """
    try_print = _make_printer(quiet)

    base_globals: Dict[str, Any]
    if include_modules:
        import cirq
        import numpy
        import sympy
        import pandas

        base_globals = {'cirq': cirq, 'np': numpy, 'sympy': sympy, 'pd': pandas}
    else:
        base_globals = {}

    try_print('Loading tests ', end='')

    def make_globals(mod: ModuleType) -> Dict[str, Any]:
        # The shared base names take precedence over same-named module globals.
        if include_local:
            test_globals = dict(mod.__dict__)
            test_globals.update(base_globals)
            return test_globals
        return dict(base_globals)

    def make_test(file_path: str) -> Doctest:
        try_print('.', end='', flush=True)
        mod = import_file(file_path)
        return Doctest(file_path, mod, make_globals(mod))

    tests = [make_test(file_path) for file_path in file_paths]
    try_print()
    return tests


def exec_tests(
    tests: Iterable[Doctest], quiet: bool = True
) -> Tuple[doctest.TestResults, List[str]]:
    """Runs a list of `Doctest`s and collects and returns any error messages.

    Args:
        tests: The tests to run.
        quiet: If True, no progress output ('.' per passing file, 'F' per
            failing file) is printed.

    Returns: A tuple containing the results (# failures, # attempts) and a list
        of the error outputs from each failing test.
    """
    try_print = _make_printer(quiet)
    try_print('Executing tests ', end='')

    failed, attempted = 0, 0
    error_messages: List[str] = []
    for test in tests:
        # Capture everything the doctest run writes so it can be attached to the
        # error message of a failing file instead of interleaving with progress.
        out = OutputCapture()
        with out:
            r = test.run()
        failed += r.failed
        attempted += r.attempted
        if r.failed != 0:
            try_print('F', end='', flush=True)
            error = shell_tools.highlight(
                f'{test.file_name}\n'
                f'{r.failed} failed, {r.attempted - r.failed} passed, {r.attempted} total\n',
                shell_tools.RED,
            )
            error += out.content()
            error_messages.append(error)
        else:
            try_print('.', end='', flush=True)

    try_print()

    return doctest.TestResults(failed=failed, attempted=attempted), error_messages


def import_file(file_path: str) -> ModuleType:
    """Finds and runs a python file as if it were imported with an `import`
    statement.

    The module is registered in `sys.modules` under a fixed temporary name only
    while its code runs; the entry is always removed afterwards, even when
    executing the file raises.

    Args:
        file_path: The file to import.

    Returns: The imported module.

    Raises:
        ImportError: No module spec or loader could be created for `file_path`.
    """
    mod_name = 'cirq_doctest_module'
    # Find and create the module
    spec = importlib.util.spec_from_file_location(mod_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f'Cannot load {file_path!r} as a python module.')
    mod = importlib.util.module_from_spec(spec)
    # Run the code in the module (but not with __name__ == '__main__')
    sys.modules[mod_name] = mod
    try:
        spec.loader.exec_module(mod)
        # Like the import system, return whatever is registered under the name,
        # in case the executed code replaced its own entry.
        mod = sys.modules[mod_name]
    finally:
        # Never leak the temporary entry into the next import_file call.
        sys.modules.pop(mod_name, None)
    return mod


def main():
    """Doctests every `.py` file under `cirq/` (except the engine client code).

    Passing `-q` as the first command-line argument hides progress output.
    Exits the process with status 1 if any doctest failed, 0 otherwise.
    """
    quiet = len(sys.argv) >= 2 and sys.argv[1] == '-q'

    # Sort so the test order (and thus the error output order) does not depend
    # on the filesystem's directory listing order.
    file_names = sorted(glob.glob('cirq/**/*.py', recursive=True))
    # Remove the engine client code. Separators are normalized so the prefix
    # also matches where glob returns backslash-separated paths (Windows).
    file_names = [
        f for f in file_names if not f.replace(os.sep, '/').startswith(_ENGINE_CLIENT_PREFIX)
    ]
    failed, attempted = run_tests(
        file_names, include_modules=True, include_local=False, quiet=quiet
    )

    if failed != 0:
        print(
            shell_tools.highlight(
                'Failed: {} failed, {} passed, {} total'.format(
                    failed, attempted - failed, attempted
                ),
                shell_tools.RED,
            )
        )
        sys.exit(1)
    else:
        print(shell_tools.highlight(f'Passed: {attempted}', shell_tools.GREEN))
        sys.exit(0)


if __name__ == '__main__':
    main()