1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for executable snippets in documentation.
15
This test runs the executable code snippets found in `.md` and `.rst`
documentation. It covers all such files under the docs directory, as well as
the top level README file.
19
20In addition to checking that the code executes:
21
    * The test looks for comments of the form `# prints` and then checks
      that the output printed by the snippet matches the commented-out lines
      that follow that comment.  So if the snippet is
25
26          print('foo')
27          # prints
28          # foo
29
30      Then this checks that the print statement indeed prints 'foo'.  Note that
31      leading spaces are ignored.  If there are any characters after `# prints`,
32      like for instance `# prints something like` then this comparison is
33      not done. This is useful for documenting code that does print but
34      the output is non-deterministic.
35
36    * The test looks for substitutions that will be applied to the snippets
37      before running the code. This is useful if a documentation example has
38      a very long runtime, but can be made shorter by changing some variables
39      (like number of qubits or number of repetitions).  For `.md` files the
40      substitution is of the form
41
42            <!---test_substitution
43            pattern
44            substitution
45            --->
46
47      and for `.rst` the substitution is of the form
48
49            .. test-substitution::
50                pattern
51                substitution
52
53      where pattern is the regex matching pattern (passed to re.compile) and
54      substitution is the replacement string.
55"""
56import inspect
57import sys
58from typing import Any, Dict, List, Pattern, Tuple, Iterator
59
60import os
61import pathlib
62import re
63
64import pytest
65
66import cirq
67
68
def test_can_run_readme_code_snippets():
    """Checks that the code snippets in the README.rst file run.

    The path is relative, so it is resolved against the working directory the
    tests are started from. `assume_import=False` is passed so the snippets
    are not given an implicit `import cirq`.
    """
    readme_path = 'README.rst'
    assert_file_has_working_code_snippets(readme_path, assume_import=False)
75
76
def find_docs_code_snippets_paths() -> Iterator[str]:
    """Yields every `.md` file, then every `.rst` file, found near this module.

    The search is recursive, starting at the directory that contains this
    file, and each path is reported relative to that directory.
    """
    root = pathlib.Path(__file__).parent
    for glob_pattern in ('*.md', '*.rst'):
        for doc_path in root.rglob(glob_pattern):
            yield str(doc_path.relative_to(root))
83
84
@pytest.mark.parametrize('path', find_docs_code_snippets_paths())
def test_can_run_docs_code_snippets(path):
    """Runs the snippets of one documentation file with `import cirq` assumed."""
    doc_path = os.path.join(os.path.dirname(__file__), path)
    assert_file_has_working_code_snippets(doc_path, assume_import=True)
89
90
def find_code_snippets(pattern: str, content: str) -> List[Tuple[str, int]]:
    """Finds all matches of a pattern and the line number associated with each.

    Args:
        pattern: Regular expression with at least one capture group. It is
            searched with `re.MULTILINE | re.DOTALL`.
        content: The text to search.

    Returns:
        A list of `(group_1_text, line_number)` pairs in order of appearance.
        The line number is the 1-based number of the line that begins right
        after the first newline at or after the start of the match. For a
        pattern that starts with a newline this is the line holding the rest
        of the match. If no newline follows the start of the match, the
        number of the last line of `content` is used.
    """
    snippets = []
    line_number = 1
    counted_to = 0  # Newlines in content[:counted_to] are already in line_number.
    for match in re.finditer(pattern, content, re.MULTILINE | re.DOTALL):
        newline_pos = content.find('\n', match.start())
        # Matches arrive in increasing start order, so `end` never moves
        # backwards and each newline is counted exactly once overall.
        end = len(content) if newline_pos == -1 else newline_pos + 1
        line_number += content.count('\n', counted_to, end)
        counted_to = end
        # Every match is recorded, including one with no newline after it and
        # several matches that share a line.
        snippets.append((match.group(1), line_number))
    return snippets
103
104
def find_markdown_code_snippets(content: str) -> List[Tuple[str, int]]:
    """Finds the bodies of ```python fenced blocks, with their line numbers."""
    fenced_python_block = "\n```python(.*?)\n```\n"
    return find_code_snippets(fenced_python_block, content)
107
108
def find_markdown_test_overrides(content: str) -> List[Tuple[Pattern, str]]:
    """Collects the `<!---test_substitution ... --->` rules of a markdown text.

    Each rule body must consist of exactly two lines: the regex pattern and
    then the replacement string.

    Returns:
        A list of (compiled pattern, replacement) pairs in document order.
    """
    overrides = []
    for body, _ in find_code_snippets("<!---test_substitution\n(.*?)--->", content):
        # The body ends with a newline, so the last split piece is dropped.
        regex_text, replacement = body.split('\n')[:-1]
        overrides.append((re.compile(regex_text), replacement))
    return overrides
113
114
def apply_overrides(content: str, overrides: List[Tuple[Pattern, str]]) -> str:
    """Applies every (pattern, replacement) pair to the text, in order.

    Later overrides see the result of earlier ones.
    """
    for regex, replacement in overrides:
        content = re.sub(regex, replacement, content)
    return content
120
121
def deindent_snippet(snippet: str) -> str:
    """Strips the indentation of the first non-blank line from a snippet.

    Lines before the first non-blank line are left untouched. That line and
    every line after it lose as many leading characters as that line has
    leading whitespace characters.
    """
    lines = snippet.split('\n')
    for index, line in enumerate(lines):
        first_code = re.match(r'(\s*)\S', line)
        if first_code:
            # The first line with real content fixes the indentation level.
            indent = len(first_code.group(1))
            lines[index:] = [later[indent:] for later in lines[index:]]
            break
    return '\n'.join(lines)
138
139
def find_rst_code_snippets(content: str) -> List[Tuple[str, int]]:
    """Finds `.. code-block:: python` bodies in rst text, deindented.

    Returns:
        A list of (deindented snippet, line number) pairs in document order.
    """
    code_block_pattern = r'\n.. code-block:: python\n(?:\s+:.*?\n)*\n(.*?)(?:\n\S|\Z)'
    deindented = []
    for raw_snippet, line_number in find_code_snippets(code_block_pattern, content):
        deindented.append((deindent_snippet(raw_snippet), line_number))
    return deindented
145
146
def find_rst_test_overrides(content: str) -> List[Tuple[Pattern, str]]:
    """Collects the `.. test-substitution::` rules of an rst text.

    The two lines after each directive are the regex pattern and the
    replacement. Leading whitespace is stripped from both.

    Returns:
        A list of (compiled pattern, replacement) pairs in document order.
    """
    directive_pattern = r'.. test-substitution::\n(([^\n]*\n){2})'
    overrides = []
    for body, _ in find_code_snippets(directive_pattern, content):
        # The body ends with a newline, so the last split piece is dropped.
        regex_text, replacement = body.split('\n')[:-1]
        overrides.append((re.compile(regex_text.lstrip()), replacement.lstrip()))
    return overrides
152
153
def test_find_rst_code_snippets():
    """Checks snippet text and line numbers for three rst code blocks."""
    rst_text = """
A 3 by 3 grid of qubits using

.. code-block:: python

    print("hello world")

The next level up.

.. code-block:: python
    :emphasize-lines: 3,5

    print("hello 1")

    for i in range(10):
        print(f"hello {i}")

More text.

.. code-block:: python

    print("last line")
"""
    expected = [
        ('print("hello world")\n', 4),
        ('print("hello 1")\n\nfor i in range(10):\n    print(f"hello {i}")\n', 10),
        ('print("last line")\n', 20),
    ]

    assert find_rst_code_snippets(rst_text) == expected
186
187
def test_find_rst_overrides():
    """Checks that both rst substitution directives are found and parsed."""
    rst_text = """
A 3 by 3 grid of qubits using

.. code-block:: python

    print("hello world")
    print("golden")

.. test-substitution::
    hello world
    goodbye cruel world

.. test-substitution::
    golden
    yellow
"""
    overrides = find_rst_test_overrides(rst_text)

    assert len(overrides) == 2
    (first_pattern, first_sub), (second_pattern, second_sub) = overrides
    assert first_pattern.match('hello world')
    assert second_pattern.match('golden')
    assert first_sub == 'goodbye cruel world'
    assert second_sub == 'yellow'
212
213
def test_apply_rst_overrides():
    """Checks that rst substitutions rewrite the whole text.

    The substitutions are applied to the entire content, so the pattern lines
    inside the directives are rewritten as well.
    """
    content = """
A 3 by 3 grid of qubits using

.. code-block:: python

    print("hello world")
    print("golden")

.. test-substitution::
    hello world
    goodbye cruel world

.. test-substitution::
    golden
    yellow
"""
    expected = """
A 3 by 3 grid of qubits using

.. code-block:: python

    print("goodbye cruel world")
    print("yellow")

.. test-substitution::
    goodbye cruel world
    goodbye cruel world

.. test-substitution::
    yellow
    yellow
"""
    overrides = find_rst_test_overrides(content)
    assert apply_overrides(content, overrides) == expected
252
253
def test_find_markdown_code_snippets():
    """Checks snippet text and line numbers for three fenced python blocks."""
    markdown_text = """
A 3 by 3 grid of qubits using

```python
print("hello world")
```

The next level up.

```python
print("hello 1")

for i in range(10):
    print(f"hello {i}")
```

More text.

```python
print("last line")
```
"""
    expected = [
        ('\nprint("hello world")', 4),
        ('\nprint("hello 1")\n\nfor i in range(10):\n    print(f"hello {i}")', 10),
        ('\nprint("last line")', 19),
    ]

    assert find_markdown_code_snippets(markdown_text) == expected
285
286
def test_find_markdown_test_overrides():
    """Checks that both markdown substitution comments are found and parsed."""
    markdown_text = """
A 3 by 3 grid of qubits using

```python
print("hello world")
```
<!---test_substitution
hello
goodbye
--->
<!---test_substitution
world
universe
--->
"""
    overrides = find_markdown_test_overrides(markdown_text)

    assert len(overrides) == 2
    (first_pattern, first_sub), (second_pattern, second_sub) = overrides
    assert first_pattern.match('hello')
    assert second_pattern.match('world')
    assert first_sub == 'goodbye'
    assert second_sub == 'universe'
311
312
def test_apply_overrides_markdown():
    """Checks that markdown substitutions rewrite the whole text.

    The substitutions are applied to the entire content, so the pattern lines
    inside the substitution comments are rewritten as well.
    """
    content = """
A 3 by 3 grid of qubits using

```python
print("hello world")
```
<!---test_substitution
hello
goodbye
--->
<!---test_substitution
world
universe
--->
"""
    expected = """
A 3 by 3 grid of qubits using

```python
print("goodbye universe")
```
<!---test_substitution
goodbye
goodbye
--->
<!---test_substitution
universe
universe
--->
"""
    overrides = find_markdown_test_overrides(content)
    assert apply_overrides(content, overrides) == expected
348
349
def assert_file_has_working_code_snippets(path: str, assume_import: bool):
    """Checks that the code snippets in a documentation file actually run.

    Args:
        path: File to read (as UTF-8). A path ending in `.md` is parsed as
            markdown; anything else is parsed as rst.
        assume_import: Forwarded to `assert_code_snippets_run_in_sequence`.
    """
    with open(path, encoding='utf-8') as f:
        content = f.read()

    # Pick the parser pair that matches the file format.
    if path.endswith('.md'):
        find_overrides, find_snippets = find_markdown_test_overrides, find_markdown_code_snippets
    else:
        find_overrides, find_snippets = find_rst_test_overrides, find_rst_code_snippets

    # Substitutions are applied to the text before the snippets are extracted.
    substituted = apply_overrides(content, find_overrides(content))
    assert_code_snippets_run_in_sequence(find_snippets(substituted), assume_import)
366
367
def assert_code_snippets_run_in_sequence(snippets: List[Tuple[str, int]], assume_import: bool):
    """Checks that a sequence of code snippets actually run.

    All snippets share one namespace, so imports and variables defined by an
    earlier snippet are visible to the later ones.

    Args:
        snippets: (code, line number) pairs, executed in the given order.
        assume_import: If true, `import cirq` is executed in the shared
            namespace before the first snippet.
    """
    namespace: Dict[str, Any] = {}

    if assume_import:
        exec('import cirq', namespace)

    for code, line_number in snippets:
        assert_code_snippet_executes_correctly(code, namespace, line_number)
382
383
384def _canonicalize_printed_line_chunk(chunk: str) -> str:
385    chunk = ' ' + chunk + ' '
386
387    # Reduce trailing '.0' at end of number.
388    chunk = chunk.replace('.0-', '. -')
389    chunk = chunk.replace('.0+', '. +')
390
391    # Remove leading spacing.
392    while '[ ' in chunk:
393        chunk = chunk.replace('[ ', '[')
394
395    # Remove sign before zero.
396    chunk = chunk.replace('-0 ', '+0 ')
397    chunk = chunk.replace('-0. ', '+0. ')
398    chunk = chunk.replace('-0j', '+0j')
399    chunk = chunk.replace('-0.j', '+0.j')
400
401    # Remove possibly-redundant + sign.
402    chunk = chunk.replace(' +0. ', ' 0. ')
403    chunk = chunk.replace(' +0.j', ' 0.j')
404
405    # Remove double-spacing.
406    while '  ' in chunk:
407        chunk = chunk.replace('  ', ' ')
408
409    # Remove spaces before imaginary unit.
410    while ' j' in chunk:
411        chunk = chunk.replace(' j', 'j')
412
413    # Remove padding spaces.
414    chunk = chunk.strip()
415
416    if chunk.startswith('+'):
417        chunk = chunk[1:]
418
419    return chunk
420
421
def canonicalize_printed_line(line: str) -> str:
    """Remove minor variations between outputs on some systems.

    Numpy is inconsistent about where it puts spaces and minus signs on 0s.
    Every bracketed span of the line that contains a '.' is treated as
    numpy-like output: its spacing is normalized and signed zeroes become
    plain zeroes. Trailing whitespace of the line is removed as well.

    Args:
        line: The line to canonicalize.

    Returns:
        The canonicalized line.
    """

    def canonicalize_bracketed(match):
        # Only the text between the brackets is rewritten.
        return '[' + _canonicalize_printed_line_chunk(match.group(1)) + ']'

    return re.sub(r"\[([^\]]+\.[^\]]*)\]", canonicalize_bracketed, line).rstrip()
446
447
def test_canonicalize_printed_line():
    """Checks one exact result and several groups of equivalent spellings."""
    mixed = 'first [-0.5-0.j   0. -0.5j] then [-0.  0.]'
    assert canonicalize_printed_line(mixed) == 'first [-0.5+0.j 0. -0.5j] then [0. 0.]'

    # All spellings within a group must canonicalize to the same line.
    equivalent_groups = [
        [
            '[-0.5-0.j   0. -0.5j  0. -0.5j -0.5+0.j ]',
            '[-0.5-0. j  0. -0.5j  0. -0.5j -0.5+0. j]',
        ],
        ['[2.2]', '[+2.2]', '[ 2.2]'],
        ['[-0.]', '[+0.]', '[ 0.]', '[0.]'],
        ['[[ 0.+0.j 1.+0.j]', '[[0.+0.j 1.+0.j]'],
    ]
    for group in equivalent_groups:
        assert len({canonicalize_printed_line(e) for e in group}) == 1
463
464
def assert_code_snippet_executes_correctly(snippet: str, state: Dict, line_number: int = None):
    """Executes a snippet and compares output / errors to annotations.

    If the snippet contains a `# raises <ErrorType>` comment, the code before
    the comment must run (and print what its annotations say), and the code
    from the comment onwards must fail with the named error type.

    Args:
        snippet: The code to execute.
        state: Namespace the code is executed in; it is mutated.
        line_number: Optional line number, used in failure messages.

    Raises:
        AssertionError: If a `# raises` comment has no error type on the same
            line, or if one of the delegated checks fails.
    """
    # Only spaces and tabs may separate '# raises' from the error type. With
    # `\s*` the match could cross the line break and capture the first token
    # of the next line as the error type.
    raises_annotation = re.search(r"# raises[ \t]*(\S*)", snippet)
    if raises_annotation is None:
        assert_code_snippet_runs_and_prints_expected(snippet, state, line_number)
        return

    expected_failure = raises_annotation.group(1)
    if not expected_failure:
        raise AssertionError('No error type specified for # raises line.')

    split_at = raises_annotation.start()
    assert_code_snippet_runs_and_prints_expected(snippet[:split_at], state, line_number)
    assert_code_snippet_fails(snippet[split_at:], state, expected_failure)
484
485
def assert_code_snippet_runs_and_prints_expected(
    snippet: str, state: Dict, line_number: int = None
):
    """Executes a snippet and compares captured output to annotated output.

    Args:
        snippet: The code to execute.
        state: Namespace the code is executed in. Its 'print' entry is
            replaced with a capturing function.
        line_number: Optional line number, appended to failure messages.

    Raises:
        AssertionError: If the snippet raises one, or if an expected output
            line is missing. The snippet text is appended to the message.
    """
    output_lines = []  # type: List[str]
    expected_outputs = find_expected_outputs(snippet)

    def print_capture(*values, sep=' '):
        # Multi-line prints are recorded as separate lines.
        output_lines.extend(sep.join(str(e) for e in values).split('\n'))

    state['print'] = print_capture
    try:
        exec(snippet, state)

        assert_expected_lines_present_in_order(expected_outputs, output_lines)
    except AssertionError as ex:
        location = "" if line_number is None else " (line {})".format(line_number)
        # A bare `assert` in the snippet has no args, and an assert message
        # may be a non-string; neither case may hide the original failure.
        original_msg = str(ex.args[0]) if ex.args else ''
        new_msg = original_msg + '\n\nIn snippet{}:\n{}'.format(location, _indent([snippet]))
        ex.args = (new_msg,) + tuple(ex.args[1:])
        raise
507
508
def assert_code_snippet_fails(snippet: str, state: Dict, expected_failure_type: str):
    """Checks that executing a snippet raises the named exception type.

    The name is compared against the class names in the method resolution
    order of the raised exception, so naming a base class is accepted too.

    Raises:
        AssertionError: If the snippet completes, or raises an exception
            whose class hierarchy does not contain the expected name.
    """
    try:
        exec(snippet, state)
    except Exception as ex:
        hierarchy = [cls.__name__ for cls in inspect.getmro(type(ex))]
        if expected_failure_type in hierarchy:
            return
        raise AssertionError(
            'Expected snippet to raise a {}, but it raised a {}.'.format(
                expected_failure_type, ' -> '.join(hierarchy)
            )
        )
    else:
        raise AssertionError('Expected snippet to fail, but it ran to completion.')
523
524
def assert_expected_lines_present_in_order(expected_lines: List[str], actual_lines: List[str]):
    """Checks that all expected lines are present, in the given order.

    Both lists are canonicalized before comparison. Extra actual lines
    between (or around) the expected lines are permitted.

    Raises:
        AssertionError: If an expected line cannot be found after the
            position where the previous expected line was matched.
    """
    expected_lines = [canonicalize_printed_line(e) for e in expected_lines]
    actual_lines = [canonicalize_printed_line(e) for e in actual_lines]

    def describe_missing(missing: str) -> str:
        # Only called when an assertion fails.
        differences = cirq.testing.highlight_text_differences(
            '\n'.join(actual_lines), '\n'.join(expected_lines)
        )
        return (
            'Missing expected line: {!r}\n'
            '\n'
            'Actual lines:\n'
            '{}\n'
            '\n'
            'Expected lines:\n'
            '{}\n'
            '\n'
            'Highlighted Differences:\n'
            '{}\n'
            ''.format(
                missing,
                _indent(actual_lines),
                _indent(expected_lines),
                _indent([differences]),
            )
        )

    # A single shared iterator: each search resumes right after the line that
    # satisfied the previous expected line.
    unmatched = iter(actual_lines)
    for expected in expected_lines:
        assert any(actual == expected for actual in unmatched), describe_missing(expected)
563
564
def find_expected_outputs(snippet: str) -> List[str]:
    """Finds expected output lines within a snippet.

    Recording starts after a line consisting of '# print' or '# prints'
    (optionally followed by ':'). While recording, each line that starts
    with '# ' or is exactly '#' contributes its text after the first two
    characters. The first line that breaks this pattern stops the recording
    and is not itself examined as a new header.

    Adding words after '# prints' (for example '# prints something like')
    means the line is not a header, so the lines below it are not recorded.
    This is useful for output that is not deterministic.
    """
    prefix = '# '
    bare_marker = prefix.strip()
    collected = []

    recording = False
    for line in snippet.split('\n'):
        if not recording:
            # Matches '# print', '# prints', '# print:', and '# prints:'
            recording = re.match(r'^#\s*prints?:?\s*$', line) is not None
            continue

        if line == bare_marker or line.startswith(prefix):
            collected.append(line[len(prefix) :])
        else:
            recording = False

    return collected
593
594
595def _indent(lines: List[str]) -> str:
596    return '\t' + '\n'.join(lines).replace('\n', '\n\t')
597
598
def test_find_expected_outputs():
    """Checks header recognition and recording rules of find_expected_outputs."""
    # Every accepted header spelling records the directly following comment
    # line and stops at the blank line, so '# def' is never collected.
    accepted_headers = ['# print', '# prints', '# print:', '#print:', '# prints:', '#prints:']
    for header in accepted_headers:
        snippet = '\n{}\n# abc\n\n# def\n    '.format(header)
        assert find_expected_outputs(snippet) == ['abc'], header

    # Several annotated blocks accumulate; text after the first two
    # characters of each comment line is kept verbatim.
    two_blocks = (
        '\n'
        'lorem ipsum\n'
        '\n'
        '# prints\n'
        '#   abc\n'
        '\n'
        'a wondrous collection\n'
        '\n'
        '# prints\n'
        '# def\n'
        '# ghi\n'
        '    '
    )
    assert find_expected_outputs(two_blocks) == ['  abc', 'def', 'ghi']

    # Words after '# prints' disable the check for the lines that follow.
    unchecked = (
        '\n'
        'a wandering adventurer\n'
        '\n'
        '# prints something like\n'
        '#  prints\n'
        '#prints\n'
        '# pants\n'
        '# trance\n'
        '    '
    )
    assert find_expected_outputs(unchecked) == []
704
705
def test_assert_expected_lines_present_in_order():
    """Checks accepted and rejected (expected, actual) line combinations."""
    accepted = [
        ([], []),
        ([], ['abc']),
        (['abc'], ['abc']),
        (['abc', 'def'], ['abc', 'def']),
        (['abc', 'def'], ['abc', 'interruption', 'def']),
        # Trailing whitespace is ignored on either side.
        (['abc    '], ['abc']),
        (['abc'], ['abc      ']),
    ]
    for expected, actual in accepted:
        assert_expected_lines_present_in_order(expected_lines=expected, actual_lines=actual)

    rejected = [
        (['abc'], []),
        (['abc', 'def'], ['def', 'abc']),
    ]
    for expected, actual in rejected:
        with pytest.raises(AssertionError):
            assert_expected_lines_present_in_order(expected_lines=expected, actual_lines=actual)
732
733
def test_assert_code_snippet_executes_correctly():
    """Exercises plain execution, `# prints` checks and `# raises` checks."""
    assert_code_snippet_executes_correctly("a = 1", {})
    assert_code_snippet_executes_correctly("a = b", {'b': 1})

    # The namespace that is passed in receives the snippet's variables.
    namespace = {}
    assert_code_snippet_executes_correctly("a = 1", namespace)
    assert namespace['a'] == 1

    # Errors without a '# raises' annotation propagate unchanged.
    for bad_code, error_type in [("a = b", NameError), ("a = ;", SyntaxError)]:
        with pytest.raises(error_type):
            assert_code_snippet_executes_correctly(bad_code, {})

    assert_code_snippet_executes_correctly('\nprint("abc")\n# prints\n# abc\n        ', {})

    if sys.version_info[0] >= 3:  # Our print capture only works in python 3.
        with pytest.raises(AssertionError):
            assert_code_snippet_executes_correctly(
                '\nprint("abc")\n# prints\n# def\n                ', {}
            )

    satisfied_annotations = [
        '\n# raises ZeroDivisionError\na = 1 / 0\n    ',
        '\n# raises ArithmeticError\na = 1 / 0\n        ',
        '\n# prints 123\nprint("123")\n\n# raises SyntaxError\nprint "abc")\n        ',
    ]
    for code in satisfied_annotations:
        assert_code_snippet_executes_correctly(code, {})

    violated_annotations = [
        '\n# raises ValueError\na = 1 / 0\n            ',
        '\n# raises\na = 1\n            ',
    ]
    for code in violated_annotations:
        with pytest.raises(AssertionError):
            assert_code_snippet_executes_correctly(code, {})
812