1#!/usr/bin/env python
2
3# Copyright 2019 The Cirq Developers
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     https://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16"""Locates imports that violate cirq's submodule dependencies.
17
18Specifically, this test treats the modules as a tree structure where `cirq` is
19the root, each submodule is a node and each python file is a leaf node.  While
a node (module) is in the process of being imported, it may not cause any node
other than its own children to be imported for the first time.  Importing a
module that was already imported earlier by `cirq.__init__` is allowed.  This
is currently only enforced for the first level of submodules under cirq, not
sub-submodules.
24
25Usage:
26    dev_tools/import_test.py [-h] [--time] [--others]
27
28    optional arguments:
29      -h, --help  show this help message and exit
30      --time      print a report of the modules that took the longest to import
31      --others    also track packages other than cirq and print when they are
32                  imported
33"""
34
35from typing import List
36
37import argparse
38import collections
39import os.path
40import subprocess
41import sys
42import time
43
def _build_parser() -> argparse.ArgumentParser:
    """Creates the command-line parser used when this file is run as a script.

    Every option is a boolean flag that defaults to False.
    """
    cli = argparse.ArgumentParser(
        description="Locates imports that violate cirq's submodule dependencies."
    )
    # Insertion order is preserved, so the flags appear in --help in this order.
    flag_help = {
        '--time': 'print a report of the modules that took the longest to import',
        '--others': 'also track packages other than cirq and print when they are imported',
    }
    for flag, text in flag_help.items():
        cli.add_argument(flag, action='store_true', help=text)
    return cli


parser = _build_parser()
57
58
def verify_import_tree(depth: int = 1, track_others: bool = False, timeit: bool = False) -> bool:
    """Locates imports that violate cirq's submodule dependencies by
    instrumenting python import machinery then importing cirq.

    Logs when each submodule (up to the given depth) begins and ends executing
    during import and prints an error when any import within a submodule causes
    a neighboring module to be imported for the first time.  The indent
    pattern of the printed output will match the module tree structure if the
    imports are all valid.  Otherwise an error is printed indicating the
    location of the invalid import.

    Output for valid imports:
        Start cirq
          ...
          Start cirq.study
          End   cirq.study
          Start cirq.circuits
          End   cirq.circuits
          Start cirq.schedules
          End   cirq.schedules
          ...
        End   cirq

    Output for an invalid import in `cirq/circuits/circuit.py`:
        Start cirq
          ...
          Start cirq.study
          End   cirq.study
          Start cirq.circuits
        ERROR: cirq.circuits.circuit imported cirq.vis
            Start cirq.vis
            End   cirq.vis
            ...  # Possibly more errors caused by the first.
          End   cirq.circuits
          Start cirq.schedules
          End   cirq.schedules
          ...
        End   cirq

        Invalid import: cirq.circuits.circuit imported cirq.vis

    Args:
        depth: How deep in the module tree to verify.  If depth is 1, verifies
            that submodules of cirq like cirq.ops don't import cirq.circuits.
            If depth is 2, verifies that submodules and sub-submodules like
            cirq.ops.raw_types don't import cirq.ops.common_gates or
            cirq.circuits.
        track_others: If True, logs where cirq first imports an external package
            in addition to logging when cirq modules are imported.
        timeit: Measure the import time of cirq and each submodule and print a
            report of the worst.  Includes times for external packages used by
            cirq if `track_others` is True.

    Returns:
        True if no import issues were found, False otherwise.

    Raises:
        RuntimeError: If cirq was already imported before the instrumentation
            was installed.  In that case the instrumented import would be a
            cache hit and nothing could be verified.
    """
    fail_list: List[str] = []
    start_times = {}
    load_times = {}
    # Position in the module tree implied by the modules currently executing.
    current_path: List[str] = []
    # Stack of cirq module paths whose bodies are currently executing; the
    # leading [] sentinel stands for "outside any cirq module".
    currently_running_paths: List[List[str]] = [[]]
    import_depth = 0
    indent = ' ' * 2

    def wrap_module(module):
        # Called when `module` starts executing.
        nonlocal import_depth
        start_times[module.__name__] = time.perf_counter()

        path = module.__name__.split('.')
        if path[0] != 'cirq':
            if len(path) == 1:
                print(f'{indent * import_depth}Other {module.__name__}')
            return module

        currently_running_paths.append(path)
        if len(path) == len(current_path) + 1 and path[:-1] == current_path:
            # Move down in tree
            current_path.append(path[-1])
        else:
            # Jump somewhere else in the tree
            handle_error(currently_running_paths[-2], path)
            current_path[:] = path
        if len(path) <= depth + 1:
            print(f'{indent * import_depth}Start {module.__name__}')
            import_depth += 1

        return module

    def after_exec(module):
        # Called when `module` finishes executing.
        nonlocal import_depth
        load_times[module.__name__] = time.perf_counter() - start_times[module.__name__]

        path = module.__name__.split('.')
        if path[0] != 'cirq':
            return

        assert path == currently_running_paths.pop(), 'Unexpected import state'
        if len(path) <= depth + 1:
            import_depth -= 1
            print(f'{indent * import_depth}End   {module.__name__}')
        if path == current_path:
            # No submodules were here
            current_path.pop()
        elif len(path) == len(current_path) - 1 and path == current_path[:-1]:
            # Move up in tree
            current_path.pop()
        else:
            # Jump somewhere else in the tree
            current_path[:] = path[:-1]

    def handle_error(import_from, import_to):
        # Only a violation when the importer and the importee differ within the
        # first `depth` levels below `cirq`.
        if import_from[: depth + 1] != import_to[: depth + 1]:
            msg = f"{'.'.join(import_from)} imported {'.'.join(import_to)}"
            fail_list.append(msg)
            print(f'ERROR: {msg}')

    orig_path = list(sys.path)
    project_dir = os.path.dirname(os.path.dirname(__file__))
    cirq_dir = os.path.join(project_dir, 'cirq')
    try:
        # Import wrap_module_executions *without* importing the cirq package:
        # put the cirq/ directory itself on the path (exposing cirq/_import.py)
        # and import `_import` as a top-level module.  `from cirq._import ...`
        # would execute cirq/__init__.py first, turning the instrumented
        # `import cirq` below into a no-op cache hit.
        sys.path.append(cirq_dir)
        from _import import wrap_module_executions  # type: ignore
    finally:
        sys.path[:] = orig_path  # Restore the path.

    if 'cirq' in sys.modules:
        raise RuntimeError(
            'cirq was imported before the import instrumentation was installed; '
            'its submodule imports cannot be verified.'
        )

    try:
        sys.path.append(project_dir)  # Ensure the cirq package is in the path.
        # Note that the cirq.google injection changes sys.meta_path during import.
        with wrap_module_executions('' if track_others else 'cirq', wrap_module, after_exec, False):
            # Import cirq with instrumentation
            import cirq  # pylint: disable=unused-import
    finally:
        sys.path[:] = orig_path  # Restore the path.

    if fail_list:
        print()
        # Only print the first because later errors are often caused by the
        # first and not as helpful.
        print(f'Invalid import: {fail_list[0]}')

    if timeit:
        worst_loads = collections.Counter(load_times).most_common(15)
        print()
        print('Worst load times:')
        for name, dt in worst_loads:
            print(f'{dt:.3f}  {name}')

    return not fail_list
206
207
FAIL_EXIT_CODE = 65


def test_no_circular_imports():
    """Re-runs this file as a script in a fresh interpreter.

    A subprocess is needed because earlier tests in the same session have
    already imported cirq, while this check must watch cirq's import process
    from the very beginning.
    """
    exit_status = subprocess.run([sys.executable, __file__], check=False).returncode
    if exit_status != 0:
        # coverage: ignore
        if exit_status == FAIL_EXIT_CODE:
            raise Exception('Invalid import. See captured output for details.')
        raise RuntimeError('Error in subprocess')
222
223
if __name__ == '__main__':
    args = parser.parse_args()
    success = verify_import_tree(timeit=args.time, track_others=args.others)
    # A dedicated exit code lets test_no_circular_imports tell an import
    # violation apart from any other failure of this script.
    sys.exit(FAIL_EXIT_CODE if not success else 0)
228