# Copyright 2018 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for executable snippets in documentation.

This test runs the executable code snippets found in `.md` and `.rst`
documentation. It covers all such files under the docs directory, as well as
the top level README file.

In addition to checking that the code executes:

  * The test looks for comments of the form `# prints` and then checks that
    the output of the snippet matches the commented-out lines that follow
    that annotation. So if the snippet is

        print('foo')
        # prints
        # foo

    then this checks that the print statement indeed prints 'foo'. Trailing
    whitespace is ignored when comparing lines, and numpy-style array output
    is canonicalized (see `canonicalize_printed_line`). If there are any
    characters after `# prints`, like for instance `# prints something like`,
    then this comparison is not done. This is useful for documenting code
    that does print but whose output is non-deterministic.

  * The test looks for comments of the form `# raises ErrorType`. Everything
    from that comment onward must raise an exception whose type, or one of
    whose base classes, is named `ErrorType`.

  * The test looks for substitutions that will be applied to the document
    before its snippets are extracted and run. This is useful if a
    documentation example has a very long runtime, but can be made shorter by
    changing some variables (like number of qubits or number of repetitions).
    For `.md` files the substitution is of the form

        <!---test_substitution
        pattern
        substitution
        --->

    and for `.rst` the substitution is of the form

        .. test-substitution::
            pattern
            substitution

    where pattern is the regex matching pattern (passed to re.compile) and
    substitution is the replacement string.
"""
import inspect
import sys
from typing import Any, Dict, Iterator, List, Optional, Pattern, Tuple

import os
import pathlib
import re

import pytest

import cirq


def test_can_run_readme_code_snippets():
    # Check the README.rst file at the project root. The path is resolved
    # relative to the current working directory, which is presumably the
    # repository root when the tests are run.
    readme_path = 'README.rst'

    assert_file_has_working_code_snippets(readme_path, assume_import=False)


def find_docs_code_snippets_paths() -> Iterator[str]:
    """Yields paths of all `.md` and `.rst` files under this file's directory.

    The yielded paths are relative to the directory containing this file.
    """
    docs_folder = pathlib.Path(__file__).parent
    for filename in docs_folder.rglob('*.md'):
        yield str(filename.relative_to(docs_folder))
    for filename in docs_folder.rglob('*.rst'):
        yield str(filename.relative_to(docs_folder))


@pytest.mark.parametrize('path', find_docs_code_snippets_paths())
def test_can_run_docs_code_snippets(path):
    docs_folder = os.path.dirname(__file__)
    assert_file_has_working_code_snippets(os.path.join(docs_folder, path), assume_import=True)


def find_code_snippets(pattern: str, content: str) -> List[Tuple[str, int]]:
    """Finds all matches of `pattern` in `content`, along with line numbers.

    Args:
        pattern: Regex (searched with MULTILINE and DOTALL) whose first group
            captures the snippet text.
        content: The text to search.

    Returns:
        A list of (snippet, line_number) pairs. line_number is the 1-based
        line following the first newline at or after the start of the match;
        for patterns that begin with a newline this is the line holding the
        snippet's opening marker. A match with no newline at or after its
        start is reported on the last line.
    """
    matches = re.finditer(pattern, content, re.MULTILINE | re.DOTALL)
    # One newline iterator is shared across all matches (which are yielded in
    # order), so the content is scanned for newlines only once overall.
    newlines = re.finditer("\n", content)
    snippets = []
    current_line = 1
    for match in matches:
        for newline in newlines:
            current_line += 1
            if newline.start() >= match.start():
                snippets.append((match.group(1), current_line))
                break
        else:
            # No newline at or after the match start: the match sits on the
            # last line. Report it instead of silently dropping it.
            snippets.append((match.group(1), current_line))
    return snippets


def find_markdown_code_snippets(content: str) -> List[Tuple[str, int]]:
    """Finds ```python fenced blocks in markdown, with their line numbers."""
    return find_code_snippets("\n```python(.*?)\n```\n", content)


def find_markdown_test_overrides(content: str) -> List[Tuple[Pattern, str]]:
    """Finds `<!---test_substitution` comments in markdown.

    Each comment must hold exactly two lines: a regex and its replacement.
    """
    test_sub_text = find_code_snippets("<!---test_substitution\n(.*?)--->", content)
    # 'pattern\nreplacement\n' splits into [pattern, replacement, ''].
    substitutions = [line.split('\n')[:-1] for line, _ in test_sub_text]
    return [(re.compile(match), sub) for match, sub in substitutions]


def apply_overrides(content: str, overrides: List[Tuple[Pattern, str]]) -> str:
    """Applies each (pattern, replacement) substitution to content, in order.

    Replacements are regex replacement strings, so backslash escapes such as
    group references are processed.
    """
    override_content = content
    for pattern, sub in overrides:
        override_content = pattern.sub(sub, override_content)
    return override_content


def deindent_snippet(snippet: str) -> str:
    """Removes the first non-blank line's indentation from every non-empty line.

    Note that this slices a fixed number of characters off each line, so a
    line indented less than the first non-blank line loses content.
    """
    deindented_lines = []
    indentation_amount = None

    for line in snippet.split('\n'):
        # The first non-empty line determines the indentation level.
        if indentation_amount is None and re.match(r'\s*\S', line):
            leading_whitespace = re.match(r'\s*', line)
            if leading_whitespace:
                indentation_amount = len(leading_whitespace.group(0))

        if line:
            deindented_lines.append(line[indentation_amount:])
        else:
            deindented_lines.append(line)
    return '\n'.join(deindented_lines)


def find_rst_code_snippets(content: str) -> List[Tuple[str, int]]:
    """Finds `.. code-block:: python` directives in rst, with line numbers.

    Directive options such as `:emphasize-lines:` are skipped, and the block
    body is de-indented.
    """
    snippets = find_code_snippets(
        r'\n.. code-block:: python\n(?:\s+:.*?\n)*\n(.*?)(?:\n\S|\Z)', content
    )
    return [(deindent_snippet(snippet), line_number) for snippet, line_number in snippets]


def find_rst_test_overrides(content: str) -> List[Tuple[Pattern, str]]:
    """Finds `.. test-substitution::` directives in rst.

    Each directive is followed by exactly two (indented) lines: a regex and its
    replacement. Leading whitespace is stripped from both.
    """
    # Find ".. test-substitution::" (dots escaped so they match literally).
    test_sub_text = find_code_snippets(r'\.\. test-substitution::\n(([^\n]*\n){2})', content)
    substitutions = [line.split('\n')[:-1] for line, _ in test_sub_text]
    return [(re.compile(match.lstrip()), sub.lstrip()) for match, sub in substitutions]


def test_find_rst_code_snippets():
    snippets = find_rst_code_snippets(
        """
A 3 by 3 grid of qubits using

.. code-block:: python

    print("hello world")

The next level up.

.. code-block:: python
    :emphasize-lines: 3,5

    print("hello 1")

    for i in range(10):
        print(f"hello {i}")

More text.

.. code-block:: python

    print("last line")
"""
    )

    assert snippets == [
        ('print("hello world")\n', 4),
        ('print("hello 1")\n\nfor i in range(10):\n    print(f"hello {i}")\n', 10),
        ('print("last line")\n', 20),
    ]


def test_find_rst_overrides():
    overrides = find_rst_test_overrides(
        """
A 3 by 3 grid of qubits using

.. code-block:: python

    print("hello world")
    print("golden")

.. test-substitution::
    hello world
    goodbye cruel world

.. test-substitution::
    golden
    yellow
"""
    )
    assert len(overrides) == 2
    assert overrides[0][0].match('hello world')
    assert overrides[1][0].match('golden')
    assert overrides[0][1] == 'goodbye cruel world'
    assert overrides[1][1] == 'yellow'


def test_apply_rst_overrides():
    content = """
A 3 by 3 grid of qubits using

.. code-block:: python

    print("hello world")
    print("golden")

.. test-substitution::
    hello world
    goodbye cruel world

.. test-substitution::
    golden
    yellow
"""
    overrides = find_rst_test_overrides(content)
    assert (
        apply_overrides(content, overrides)
        == """
A 3 by 3 grid of qubits using

.. code-block:: python

    print("goodbye cruel world")
    print("yellow")

.. test-substitution::
    goodbye cruel world
    goodbye cruel world

.. test-substitution::
    yellow
    yellow
"""
    )


def test_find_markdown_code_snippets():
    snippets = find_markdown_code_snippets(
        """
A 3 by 3 grid of qubits using

```python
print("hello world")
```

The next level up.

```python
print("hello 1")

for i in range(10):
    print(f"hello {i}")
```

More text.

```python
print("last line")
```
"""
    )

    assert snippets == [
        ('\nprint("hello world")', 4),
        ('\nprint("hello 1")\n\nfor i in range(10):\n    print(f"hello {i}")', 10),
        ('\nprint("last line")', 19),
    ]


def test_find_markdown_test_overrides():
    overrides = find_markdown_test_overrides(
        """
A 3 by 3 grid of qubits using

```python
print("hello world")
```
<!---test_substitution
hello
goodbye
--->
<!---test_substitution
world
universe
--->
"""
    )

    assert len(overrides) == 2
    assert overrides[0][0].match('hello')
    assert overrides[1][0].match('world')
    assert overrides[0][1] == 'goodbye'
    assert overrides[1][1] == 'universe'


def test_apply_overrides_markdown():
    content = """
A 3 by 3 grid of qubits using

```python
print("hello world")
```
<!---test_substitution
hello
goodbye
--->
<!---test_substitution
world
universe
--->
"""
    overrides = find_markdown_test_overrides(content)
    assert (
        apply_overrides(content, overrides)
        == """
A 3 by 3 grid of qubits using

```python
print("goodbye universe")
```
<!---test_substitution
goodbye
goodbye
--->
<!---test_substitution
universe
universe
--->
"""
    )


def assert_file_has_working_code_snippets(path: str, assume_import: bool):
    """Checks that code snippets in a file actually run.

    Files ending in `.md` are parsed as markdown; anything else as rst.
    Test substitutions are applied to the whole file before the snippets are
    extracted.
    """

    with open(path, encoding='utf-8') as f:
        content = f.read()

    # Find snippets of code, and execute them. They should finish.
    if path.endswith('.md'):
        overrides = find_markdown_test_overrides(content)
        content = apply_overrides(content, overrides)
        snippets = find_markdown_code_snippets(content)
    else:
        overrides = find_rst_test_overrides(content)
        content = apply_overrides(content, overrides)
        snippets = find_rst_code_snippets(content)
    assert_code_snippets_run_in_sequence(snippets, assume_import)


def assert_code_snippets_run_in_sequence(snippets: List[Tuple[str, int]], assume_import: bool):
    """Checks that a sequence of code snippets actually run.

    State is kept between snippets. Imports and variables defined in one
    snippet will be visible in later snippets.
    """

    state: Dict[str, Any] = {}

    if assume_import:
        exec('import cirq', state)

    for content, line_number in snippets:
        assert_code_snippet_executes_correctly(content, state, line_number)


def _canonicalize_printed_line_chunk(chunk: str) -> str:
    """Normalizes numpy-style spacing and signed zeros in a bracketed chunk."""
    chunk = ' ' + chunk + ' '

    # Reduce trailing '.0' at end of number.
    chunk = chunk.replace('.0-', '. -')
    chunk = chunk.replace('.0+', '. +')

    # Remove leading spacing.
    while '[ ' in chunk:
        chunk = chunk.replace('[ ', '[')

    # Remove sign before zero.
    chunk = chunk.replace('-0 ', '+0 ')
    chunk = chunk.replace('-0. ', '+0. ')
    chunk = chunk.replace('-0j', '+0j')
    chunk = chunk.replace('-0.j', '+0.j')

    # Remove possibly-redundant + sign.
    chunk = chunk.replace(' +0. ', ' 0. ')
    chunk = chunk.replace(' +0.j', ' 0.j')

    # Remove double-spacing.
    while '  ' in chunk:
        chunk = chunk.replace('  ', ' ')

    # Remove spaces before imaginary unit.
    while ' j' in chunk:
        chunk = chunk.replace(' j', 'j')

    # Remove padding spaces.
    chunk = chunk.strip()

    if chunk.startswith('+'):
        chunk = chunk[1:]

    return chunk


def canonicalize_printed_line(line: str) -> str:
    """Remove minor variations between outputs on some systems.

    Basically, numpy is extremely inconsistent about where it puts spaces and
    minus signs on 0s. This method goes through the line looking for stuff
    that looks like it came from numpy (bracketed spans containing a '.'),
    and if so then strips out spacing and turns signed zeroes into just
    zeroes. Trailing whitespace on the line is removed.

    Args:
        line: The line to canonicalize.

    Returns:
        The canonicalized line.
    """
    prev_end = 0
    result = []
    for match in re.finditer(r"\[([^\]]+\.[^\]]*)\]", line):
        start = match.start() + 1
        end = match.end() - 1
        result.append(line[prev_end:start])
        result.append(_canonicalize_printed_line_chunk(line[start:end]))
        prev_end = end
    result.append(line[prev_end:])
    return ''.join(result).rstrip()


def test_canonicalize_printed_line():
    x = 'first [-0.5-0.j   0. -0.5j] then [-0.  0.]'
    assert canonicalize_printed_line(x) == ('first [-0.5+0.j 0. -0.5j] then [0. 0.]')

    a = '[-0.5-0.j   0. -0.5j  0. -0.5j -0.5+0.j ]'
    b = '[-0.5-0. j  0. -0.5j  0. -0.5j -0.5+0. j]'
    assert canonicalize_printed_line(a) == canonicalize_printed_line(b)

    assert len({canonicalize_printed_line(e) for e in ['[2.2]', '[+2.2]', '[ 2.2]']}) == 1

    assert len({canonicalize_printed_line(e) for e in ['[-0.]', '[+0.]', '[ 0.]', '[0.]']}) == 1

    a = '[[ 0.+0.j 1.+0.j]'
    b = '[[0.+0.j  1.+0.j]'
    assert canonicalize_printed_line(a) == canonicalize_printed_line(b)


def assert_code_snippet_executes_correctly(
    snippet: str, state: Dict, line_number: Optional[int] = None
):
    """Executes a snippet and compares output / errors to annotations.

    The part of the snippet before a `# raises ErrorType` comment (or the
    whole snippet, if there is none) must run and print the lines annotated
    with `# prints`. The part from the `# raises` comment onward must raise
    an exception whose type, or one of whose base classes, is named
    `ErrorType`.

    Args:
        snippet: The code to execute.
        state: The globals dict to execute in; it is updated in place.
        line_number: Line of the source document where the snippet appears,
            used only in failure messages.

    Raises:
        AssertionError: If output or failures don't match the annotations,
            or if a `# raises` comment does not name an error type.
    """

    # Only spaces/tabs may separate '# raises' from the error type. `\s` would
    # also match a newline and wrongly take the next line's first token.
    raises_annotation = re.search(r"# raises[ \t]*(\S*)", snippet)
    if raises_annotation is None:
        before = snippet
        after = None
        expected_failure = None
    else:
        before = snippet[: raises_annotation.start()]
        after = snippet[raises_annotation.start() :]
        expected_failure = raises_annotation.group(1)
        if not expected_failure:
            raise AssertionError('No error type specified for # raises line.')

    assert_code_snippet_runs_and_prints_expected(before, state, line_number)
    if expected_failure is not None:
        assert after is not None
        assert_code_snippet_fails(after, state, expected_failure)


def assert_code_snippet_runs_and_prints_expected(
    snippet: str, state: Dict, line_number: Optional[int] = None
):
    """Executes a snippet and compares captured output to annotated output.

    `print` in `state` is replaced by a function that records the printed
    text (split into lines) so it can be checked against the snippet's
    `# prints` annotations. Assertion failures are re-raised with the snippet
    (and its line number, if known) appended to the message.
    """
    output_lines: List[str] = []
    expected_outputs = find_expected_outputs(snippet)

    def print_capture(*values, sep=' '):
        output_lines.extend(sep.join(str(e) for e in values).split('\n'))

    state['print'] = print_capture
    try:
        exec(snippet, state)

        assert_expected_lines_present_in_order(expected_outputs, output_lines)
    except AssertionError as ex:
        # A bare `assert` in the snippet raises AssertionError with no args,
        # and args[0] need not be a string, so don't index/concatenate blindly.
        original_msg = str(ex.args[0]) if ex.args else ''
        new_msg = original_msg + '\n\nIn snippet{}:\n{}'.format(
            "" if line_number is None else " (line {})".format(line_number),
            _indent([snippet]),
        )
        ex.args = (new_msg,) + tuple(ex.args[1:])
        raise


def assert_code_snippet_fails(snippet: str, state: Dict, expected_failure_type: str):
    """Checks that executing the snippet raises the named exception type.

    The check passes if `expected_failure_type` is the name of the raised
    exception's class or of any class in its MRO.
    """
    try:
        exec(snippet, state)
    except Exception as ex:
        actual_failure_types = [e.__name__ for e in inspect.getmro(type(ex))]
        if expected_failure_type not in actual_failure_types:
            raise AssertionError(
                'Expected snippet to raise a {}, but it raised a {}.'.format(
                    expected_failure_type, ' -> '.join(actual_failure_types)
                )
            )
        return

    raise AssertionError('Expected snippet to fail, but it ran to completion.')


def assert_expected_lines_present_in_order(expected_lines: List[str], actual_lines: List[str]):
    """Checks that all expected lines are present.

    It is permitted for there to be extra actual lines between expected lines.
    Both sides are canonicalized with `canonicalize_printed_line` first.
    """
    expected_lines = [canonicalize_printed_line(e) for e in expected_lines]
    actual_lines = [canonicalize_printed_line(e) for e in actual_lines]

    i = 0
    for expected in expected_lines:
        while i < len(actual_lines) and actual_lines[i] != expected:
            i += 1

        assert i < len(actual_lines), (
            'Missing expected line: {!r}\n'
            '\n'
            'Actual lines:\n'
            '{}\n'
            '\n'
            'Expected lines:\n'
            '{}\n'
            '\n'
            'Highlighted Differences:\n'
            '{}\n'
            ''.format(
                expected,
                _indent(actual_lines),
                _indent(expected_lines),
                _indent(
                    [
                        cirq.testing.highlight_text_differences(
                            '\n'.join(actual_lines), '\n'.join(expected_lines)
                        )
                    ]
                ),
            )
        )
        i += 1


def find_expected_outputs(snippet: str) -> List[str]:
    """Finds expected output lines within a snippet.

    Expected output must be annotated with a leading '# prints'.
    Lines below '# prints' must start with '# ' or be just '#' and not indent
    any more than that in order to add an expected line. As soon as a line
    breaks this pattern, expected output recording cuts off.

    Adding words after '# prints' causes the expected output lines to be
    skipped instead of included. For example, for random output say
    '# prints something like' to avoid checking the following lines.
    """
    continue_key = '# '
    expected = []

    printing = False
    for line in snippet.split('\n'):
        if printing:
            if line.startswith(continue_key) or line == continue_key.strip():
                rest = line[len(continue_key) :]
                expected.append(rest)
            else:
                printing = False
        # Matches '# print', '# prints', '# print:', and '# prints:'
        elif re.match(r'^#\s*prints?:?\s*$', line):
            printing = True

    return expected


def _indent(lines: List[str]) -> str:
    """Joins lines and prefixes every resulting line with a tab."""
    return '\t' + '\n'.join(lines).replace('\n', '\n\t')


def test_find_expected_outputs():
    assert (
        find_expected_outputs(
            """
# print
# abc

# def
    """
        )
        == ['abc']
    )

    assert (
        find_expected_outputs(
            """
# prints
# abc

# def
    """
        )
        == ['abc']
    )

    assert (
        find_expected_outputs(
            """
# print:
# abc

# def
    """
        )
        == ['abc']
    )

    assert (
        find_expected_outputs(
            """
#print:
# abc

# def
    """
        )
        == ['abc']
    )

    assert (
        find_expected_outputs(
            """
# prints:
# abc

# def
    """
        )
        == ['abc']
    )

    assert (
        find_expected_outputs(
            """
# prints:
# abc

# def
    """
        )
        == ['abc']
    )

    assert (
        find_expected_outputs(
            """
lorem ipsum

# prints
#  abc

a wondrous collection

# prints
# def
# ghi
    """
        )
        == [' abc', 'def', 'ghi']
    )

    assert (
        find_expected_outputs(
            """
a wandering adventurer

# prints something like
# prints
#prints
# pants
# trance
    """
        )
        == []
    )


def test_assert_expected_lines_present_in_order():
    assert_expected_lines_present_in_order(expected_lines=[], actual_lines=[])

    assert_expected_lines_present_in_order(expected_lines=[], actual_lines=['abc'])

    assert_expected_lines_present_in_order(expected_lines=['abc'], actual_lines=['abc'])

    with pytest.raises(AssertionError):
        assert_expected_lines_present_in_order(expected_lines=['abc'], actual_lines=[])

    assert_expected_lines_present_in_order(
        expected_lines=['abc', 'def'], actual_lines=['abc', 'def']
    )

    assert_expected_lines_present_in_order(
        expected_lines=['abc', 'def'], actual_lines=['abc', 'interruption', 'def']
    )

    with pytest.raises(AssertionError):
        assert_expected_lines_present_in_order(
            expected_lines=['abc', 'def'], actual_lines=['def', 'abc']
        )

    assert_expected_lines_present_in_order(expected_lines=['abc '], actual_lines=['abc'])

    assert_expected_lines_present_in_order(expected_lines=['abc'], actual_lines=['abc '])


def test_assert_code_snippet_executes_correctly():
    assert_code_snippet_executes_correctly("a = 1", {})
    assert_code_snippet_executes_correctly("a = b", {'b': 1})

    s = {}
    assert_code_snippet_executes_correctly("a = 1", s)
    assert s['a'] == 1

    with pytest.raises(NameError):
        assert_code_snippet_executes_correctly("a = b", {})

    with pytest.raises(SyntaxError):
        assert_code_snippet_executes_correctly("a = ;", {})

    # A bare `assert` raises an AssertionError with no args; the snippet
    # context must still be attached instead of failing with an IndexError.
    with pytest.raises(AssertionError, match='In snippet'):
        assert_code_snippet_executes_correctly("assert False", {})

    assert_code_snippet_executes_correctly(
        """
print("abc")
# prints
# abc
        """,
        {},
    )

    if sys.version_info[0] >= 3:  # Our print capture only works in python 3.
        with pytest.raises(AssertionError):
            assert_code_snippet_executes_correctly(
                """
print("abc")
# prints
# def
                """,
                {},
            )

    assert_code_snippet_executes_correctly(
        """
# raises ZeroDivisionError
a = 1 / 0
        """,
        {},
    )

    assert_code_snippet_executes_correctly(
        """
# raises ArithmeticError
a = 1 / 0
        """,
        {},
    )

    assert_code_snippet_executes_correctly(
        """
# prints 123
print("123")

# raises SyntaxError
print "abc")
        """,
        {},
    )

    with pytest.raises(AssertionError):
        assert_code_snippet_executes_correctly(
            """
# raises ValueError
a = 1 / 0
            """,
            {},
        )

    with pytest.raises(AssertionError, match='No error type specified'):
        assert_code_snippet_executes_correctly(
            """
# raises
a = 1
            """,
            {},
        )