# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from ast import literal_eval
from textwrap import dedent
from typing import List, Set
from unittest.mock import Mock

import libcst as cst
import libcst.matchers as m
from libcst.matchers import (
    MatcherDecoratableTransformer,
    MatcherDecoratableVisitor,
    call_if_inside,
    call_if_not_inside,
    leave,
    visit,
)
from libcst.testing.utils import UnitTest


def fixture(code: str) -> cst.Module:
    """Parse ``code`` into a module after removing its common indentation.

    This lets the test sources below be written as indented triple-quoted
    strings inside the test methods.
    """
    return cst.parse_module(dedent(code))


class MatchersGatingDecoratorsTest(UnitTest):
    """Tests for the ``call_if_inside`` / ``call_if_not_inside`` decorators."""

    def test_call_if_inside_transform_simple(self) -> None:
        # Set up a simple visitor with a call_if_inside decorator.
        class TestVisitor(MatcherDecoratableTransformer):
            def __init__(self) -> None:
                super().__init__()
                self.visits: List[str] = []
                self.leaves: List[str] = []

            @call_if_inside(m.FunctionDef(m.Name("foo")))
            def visit_SimpleString(self, node: cst.SimpleString) -> None:
                self.visits.append(node.value)

            @call_if_inside(m.FunctionDef())
            def leave_SimpleString(
                self, original_node: cst.SimpleString, updated_node: cst.SimpleString
            ) -> cst.SimpleString:
                self.leaves.append(updated_node.value)
                return updated_node

        # Parse a module and verify we visited correctly.
        module = fixture(
            """
            a = "foo"
            b = "bar"

            def foo() -> None:
                return "baz"

            def bar() -> None:
                return "foobar"
            """
        )
        visitor = TestVisitor()
        module.visit(visitor)

        # The visit is gated on foo() only; the leave is gated on any function,
        # so the module-level strings are skipped by both.
        self.assertEqual(visitor.visits, ['"baz"'])
        self.assertEqual(visitor.leaves, ['"baz"', '"foobar"'])

    def test_call_if_inside_verify_original_transform(self) -> None:
        # Set up a visitor with a gated string visit and an ungated
        # visit_FunctionDef.
        class TestVisitor(MatcherDecoratableTransformer):
            def __init__(self) -> None:
                super().__init__()
                self.func_visits: List[str] = []
                self.str_visits: List[str] = []

            @call_if_inside(m.FunctionDef(m.Name("foo")))
            def visit_SimpleString(self, node: cst.SimpleString) -> None:
                self.str_visits.append(node.value)

            def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
                self.func_visits.append(node.name.value)

        # Parse a module and verify we visited correctly.
        module = fixture(
            """
            a = "foo"
            b = "bar"

            def foo() -> None:
                return "baz"

            def bar() -> None:
                return "foobar"
            """
        )
        visitor = TestVisitor()
        module.visit(visitor)

        # The ungated visit_FunctionDef still sees every function, while the
        # gated string visitor only fires inside foo().
        self.assertEqual(visitor.func_visits, ["foo", "bar"])
        self.assertEqual(visitor.str_visits, ['"baz"'])

    def test_call_if_inside_collect_simple(self) -> None:
        # Set up a simple visitor with a call_if_inside decorator.
        class TestVisitor(MatcherDecoratableVisitor):
            def __init__(self) -> None:
                super().__init__()
                self.visits: List[str] = []
                self.leaves: List[str] = []

            @call_if_inside(m.FunctionDef(m.Name("foo")))
            def visit_SimpleString(self, node: cst.SimpleString) -> None:
                self.visits.append(node.value)

            @call_if_inside(m.FunctionDef())
            def leave_SimpleString(self, original_node: cst.SimpleString) -> None:
                self.leaves.append(original_node.value)

        # Parse a module and verify we visited correctly.
        module = fixture(
            """
            a = "foo"
            b = "bar"

            def foo() -> None:
                return "baz"

            def bar() -> None:
                return "foobar"
            """
        )
        visitor = TestVisitor()
        module.visit(visitor)

        # We should have only visited a select number of nodes.
        self.assertEqual(visitor.visits, ['"baz"'])
        self.assertEqual(visitor.leaves, ['"baz"', '"foobar"'])

    def test_call_if_inside_verify_original_collect(self) -> None:
        # Set up a visitor with a gated string visit and an ungated
        # visit_FunctionDef.
        class TestVisitor(MatcherDecoratableVisitor):
            def __init__(self) -> None:
                super().__init__()
                self.func_visits: List[str] = []
                self.str_visits: List[str] = []

            @call_if_inside(m.FunctionDef(m.Name("foo")))
            def visit_SimpleString(self, node: cst.SimpleString) -> None:
                self.str_visits.append(node.value)

            def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
                self.func_visits.append(node.name.value)

        # Parse a module and verify we visited correctly.
        module = fixture(
            """
            a = "foo"
            b = "bar"

            def foo() -> None:
                return "baz"

            def bar() -> None:
                return "foobar"
            """
        )
        visitor = TestVisitor()
        module.visit(visitor)

        # The ungated visit_FunctionDef still sees every function, while the
        # gated string visitor only fires inside foo().
        self.assertEqual(visitor.func_visits, ["foo", "bar"])
        self.assertEqual(visitor.str_visits, ['"baz"'])

    def test_multiple_visitors_collect(self) -> None:
        # Set up a simple visitor with multiple stacked gating decorators.
        class TestVisitor(MatcherDecoratableVisitor):
            def __init__(self) -> None:
                super().__init__()
                self.visits: List[str] = []

            @call_if_inside(m.ClassDef(m.Name("A")))
            @call_if_inside(m.FunctionDef(m.Name("foo")))
            def visit_SimpleString(self, node: cst.SimpleString) -> None:
                self.visits.append(node.value)

        # Parse a module and verify we visited correctly.
        module = fixture(
            """
            def foo() -> None:
                return "foo"

            class A:
                def foo(self) -> None:
                    return "baz"
            """
        )
        visitor = TestVisitor()
        module.visit(visitor)

        # Both gates must hold: the module-level foo() is not inside class A,
        # so only the string in A.foo() is visited.
        self.assertEqual(visitor.visits, ['"baz"'])

    def test_multiple_visitors_transform(self) -> None:
        # Set up a simple visitor with multiple stacked gating decorators.
        class TestVisitor(MatcherDecoratableTransformer):
            def __init__(self) -> None:
                super().__init__()
                self.visits: List[str] = []

            @call_if_inside(m.ClassDef(m.Name("A")))
            @call_if_inside(m.FunctionDef(m.Name("foo")))
            def visit_SimpleString(self, node: cst.SimpleString) -> None:
                self.visits.append(node.value)

        # Parse a module and verify we visited correctly.
        module = fixture(
            """
            def foo() -> None:
                return "foo"

            class A:
                def foo(self) -> None:
                    return "baz"
            """
        )
        visitor = TestVisitor()
        module.visit(visitor)

        # Both gates must hold: the module-level foo() is not inside class A,
        # so only the string in A.foo() is visited.
        self.assertEqual(visitor.visits, ['"baz"'])

    def test_call_if_not_inside_transform_simple(self) -> None:
        # Set up a simple visitor with a call_if_not_inside decorator.
        class TestVisitor(MatcherDecoratableTransformer):
            def __init__(self) -> None:
                super().__init__()
                self.visits: List[str] = []
                self.leaves: List[str] = []

            @call_if_not_inside(m.FunctionDef(m.Name("foo")))
            def visit_SimpleString(self, node: cst.SimpleString) -> None:
                self.visits.append(node.value)

            @call_if_not_inside(m.FunctionDef())
            def leave_SimpleString(
                self, original_node: cst.SimpleString, updated_node: cst.SimpleString
            ) -> cst.SimpleString:
                self.leaves.append(updated_node.value)
                return updated_node

        # Parse a module and verify we visited correctly.
        module = fixture(
            """
            a = "foo"
            b = "bar"

            def foo() -> None:
                return "baz"

            def bar() -> None:
                return "foobar"
            """
        )
        visitor = TestVisitor()
        module.visit(visitor)

        # The visit skips only foo(); the leave skips every function body.
        self.assertEqual(visitor.visits, ['"foo"', '"bar"', '"foobar"'])
        self.assertEqual(visitor.leaves, ['"foo"', '"bar"'])

    def test_visit_if_inot_inside_verify_original_transform(self) -> None:
        # Set up a visitor with a call_if_not_inside-gated string visit and an
        # ungated visit_FunctionDef.
        class TestVisitor(MatcherDecoratableTransformer):
            def __init__(self) -> None:
                super().__init__()
                self.func_visits: List[str] = []
                self.str_visits: List[str] = []

            @call_if_not_inside(m.FunctionDef(m.Name("foo")))
            def visit_SimpleString(self, node: cst.SimpleString) -> None:
                self.str_visits.append(node.value)

            def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
                self.func_visits.append(node.name.value)

        # Parse a module and verify we visited correctly.
        module = fixture(
            """
            a = "foo"
            b = "bar"

            def foo() -> None:
                return "baz"

            def bar() -> None:
                return "foobar"
            """
        )
        visitor = TestVisitor()
        module.visit(visitor)

        # The ungated visit_FunctionDef still sees every function.
        self.assertEqual(visitor.func_visits, ["foo", "bar"])
        self.assertEqual(visitor.str_visits, ['"foo"', '"bar"', '"foobar"'])

    def test_call_if_not_inside_collect_simple(self) -> None:
        # Set up a simple visitor with a call_if_not_inside decorator.
        class TestVisitor(MatcherDecoratableVisitor):
            def __init__(self) -> None:
                super().__init__()
                self.visits: List[str] = []
                self.leaves: List[str] = []

            @call_if_not_inside(m.FunctionDef(m.Name("foo")))
            def visit_SimpleString(self, node: cst.SimpleString) -> None:
                self.visits.append(node.value)

            @call_if_not_inside(m.FunctionDef())
            def leave_SimpleString(self, original_node: cst.SimpleString) -> None:
                self.leaves.append(original_node.value)

        # Parse a module and verify we visited correctly.
        module = fixture(
            """
            a = "foo"
            b = "bar"

            def foo() -> None:
                return "baz"

            def bar() -> None:
                return "foobar"
            """
        )
        visitor = TestVisitor()
        module.visit(visitor)

        # The visit skips only foo(); the leave skips every function body.
        self.assertEqual(visitor.visits, ['"foo"', '"bar"', '"foobar"'])
        self.assertEqual(visitor.leaves, ['"foo"', '"bar"'])

    def test_visit_if_inot_inside_verify_original_collect(self) -> None:
        # Set up a visitor with a call_if_not_inside-gated string visit and an
        # ungated visit_FunctionDef.
        class TestVisitor(MatcherDecoratableVisitor):
            def __init__(self) -> None:
                super().__init__()
                self.func_visits: List[str] = []
                self.str_visits: List[str] = []

            @call_if_not_inside(m.FunctionDef(m.Name("foo")))
            def visit_SimpleString(self, node: cst.SimpleString) -> None:
                self.str_visits.append(node.value)

            def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
                self.func_visits.append(node.name.value)

        # Parse a module and verify we visited correctly.
        module = fixture(
            """
            a = "foo"
            b = "bar"

            def foo() -> None:
                return "baz"

            def bar() -> None:
                return "foobar"
            """
        )
        visitor = TestVisitor()
        module.visit(visitor)

        # The ungated visit_FunctionDef still sees every function.
        self.assertEqual(visitor.func_visits, ["foo", "bar"])
        self.assertEqual(visitor.str_visits, ['"foo"', '"bar"', '"foobar"'])


class MatchersVisitLeaveDecoratorsTest(UnitTest):
    """Tests for the matcher-driven ``visit`` / ``leave`` decorators."""

    def test_visit_transform(self) -> None:
        # Set up a simple visitor with a visit and leave decorator.
        class TestVisitor(MatcherDecoratableTransformer):
            def __init__(self) -> None:
                super().__init__()
                self.visits: List[str] = []
                self.leaves: List[str] = []

            @visit(m.FunctionDef(m.Name("foo") | m.Name("bar")))
            def visit_function(self, node: cst.FunctionDef) -> None:
                self.visits.append(node.name.value)

            @leave(m.FunctionDef(m.Name("bar") | m.Name("baz")))
            def leave_function(
                self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
            ) -> cst.FunctionDef:
                self.leaves.append(updated_node.name.value)
                return updated_node

        # Parse a module and verify we visited correctly.
        module = fixture(
            """
            a = "foo"
            b = "bar"

            def foo() -> None:
                return "baz"

            def bar() -> None:
                return "foobar"

            def baz() -> None:
                return "foobar"
            """
        )
        visitor = TestVisitor()
        module.visit(visitor)

        # We should have only visited a select number of nodes.
        self.assertEqual(visitor.visits, ["foo", "bar"])
        self.assertEqual(visitor.leaves, ["bar", "baz"])

    def test_visit_collector(self) -> None:
        # Set up a simple visitor with a visit and leave decorator.
        class TestVisitor(MatcherDecoratableVisitor):
            def __init__(self) -> None:
                super().__init__()
                self.visits: List[str] = []
                self.leaves: List[str] = []

            @visit(m.FunctionDef(m.Name("foo") | m.Name("bar")))
            def visit_function(self, node: cst.FunctionDef) -> None:
                self.visits.append(node.name.value)

            @leave(m.FunctionDef(m.Name("bar") | m.Name("baz")))
            def leave_function(self, original_node: cst.FunctionDef) -> None:
                self.leaves.append(original_node.name.value)

        # Parse a module and verify we visited correctly.
        module = fixture(
            """
            a = "foo"
            b = "bar"

            def foo() -> None:
                return "baz"

            def bar() -> None:
                return "foobar"

            def baz() -> None:
                return "foobar"
            """
        )
        visitor = TestVisitor()
        module.visit(visitor)

        # We should have only visited a select number of nodes.
        self.assertEqual(visitor.visits, ["foo", "bar"])
        self.assertEqual(visitor.leaves, ["bar", "baz"])

    def test_stacked_visit_transform(self) -> None:
        # Set up a visitor whose visit and leave functions each carry two
        # stacked matcher decorators.
        class TestVisitor(MatcherDecoratableTransformer):
            def __init__(self) -> None:
                super().__init__()
                self.visits: List[str] = []
                self.leaves: List[str] = []

            @visit(m.FunctionDef(m.Name("foo")))
            @visit(m.FunctionDef(m.Name("bar")))
            def visit_function(self, node: cst.FunctionDef) -> None:
                self.visits.append(node.name.value)

            @leave(m.FunctionDef(m.Name("bar")))
            @leave(m.FunctionDef(m.Name("baz")))
            def leave_function(
                self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
            ) -> cst.FunctionDef:
                self.leaves.append(updated_node.name.value)
                return updated_node

        # Parse a module and verify we visited correctly.
        module = fixture(
            """
            a = "foo"
            b = "bar"

            def foo() -> None:
                return "baz"

            def bar() -> None:
                return "foobar"

            def baz() -> None:
                return "foobar"
            """
        )
        visitor = TestVisitor()
        module.visit(visitor)

        # Stacked decorators behave like an OR of their matchers.
        self.assertEqual(visitor.visits, ["foo", "bar"])
        self.assertEqual(visitor.leaves, ["bar", "baz"])

    def test_stacked_visit_collector(self) -> None:
        # Set up a visitor whose visit and leave functions each carry two
        # stacked matcher decorators.
        class TestVisitor(MatcherDecoratableVisitor):
            def __init__(self) -> None:
                super().__init__()
                self.visits: List[str] = []
                self.leaves: List[str] = []

            @visit(m.FunctionDef(m.Name("foo")))
            @visit(m.FunctionDef(m.Name("bar")))
            def visit_function(self, node: cst.FunctionDef) -> None:
                self.visits.append(node.name.value)

            @leave(m.FunctionDef(m.Name("bar")))
            @leave(m.FunctionDef(m.Name("baz")))
            def leave_function(self, original_node: cst.FunctionDef) -> None:
                self.leaves.append(original_node.name.value)

        # Parse a module and verify we visited correctly.
        module = fixture(
            """
            a = "foo"
            b = "bar"

            def foo() -> None:
                return "baz"

            def bar() -> None:
                return "foobar"

            def baz() -> None:
                return "foobar"
            """
        )
        visitor = TestVisitor()
        module.visit(visitor)

        # Stacked decorators behave like an OR of their matchers.
        self.assertEqual(visitor.visits, ["foo", "bar"])
        self.assertEqual(visitor.leaves, ["bar", "baz"])

    def test_duplicate_visit_transform(self) -> None:
        # Set up a visitor where two separate functions share the same matcher.
        class TestVisitor(MatcherDecoratableTransformer):
            def __init__(self) -> None:
                super().__init__()
                self.visits: Set[str] = set()
                self.leaves: Set[str] = set()

            @visit(m.FunctionDef(m.Name("foo")))
            def visit_function1(self, node: cst.FunctionDef) -> None:
                self.visits.add(node.name.value + "1")

            @visit(m.FunctionDef(m.Name("foo")))
            def visit_function2(self, node: cst.FunctionDef) -> None:
                self.visits.add(node.name.value + "2")

            @leave(m.FunctionDef(m.Name("bar")))
            def leave_function1(
                self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
            ) -> cst.FunctionDef:
                self.leaves.add(updated_node.name.value + "1")
                return updated_node

            @leave(m.FunctionDef(m.Name("bar")))
            def leave_function2(
                self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
            ) -> cst.FunctionDef:
                self.leaves.add(updated_node.name.value + "2")
                return updated_node

        # Parse a module and verify we visited correctly.
        module = fixture(
            """
            a = "foo"
            b = "bar"

            def foo() -> None:
                return "baz"

            def bar() -> None:
                return "foobar"

            def baz() -> None:
                return "foobar"
            """
        )
        visitor = TestVisitor()
        module.visit(visitor)

        # Both functions registered for the same matcher must have been called.
        # Sets are used so the assertion does not depend on call order.
        self.assertEqual(visitor.visits, {"foo1", "foo2"})
        self.assertEqual(visitor.leaves, {"bar1", "bar2"})

    def test_duplicate_visit_collector(self) -> None:
        # Set up a visitor where two separate functions share the same matcher.
        class TestVisitor(MatcherDecoratableVisitor):
            def __init__(self) -> None:
                super().__init__()
                self.visits: Set[str] = set()
                self.leaves: Set[str] = set()

            @visit(m.FunctionDef(m.Name("foo")))
            def visit_function1(self, node: cst.FunctionDef) -> None:
                self.visits.add(node.name.value + "1")

            @visit(m.FunctionDef(m.Name("foo")))
            def visit_function2(self, node: cst.FunctionDef) -> None:
                self.visits.add(node.name.value + "2")

            @leave(m.FunctionDef(m.Name("bar")))
            def leave_function1(self, original_node: cst.FunctionDef) -> None:
                self.leaves.add(original_node.name.value + "1")

            @leave(m.FunctionDef(m.Name("bar")))
            def leave_function2(self, original_node: cst.FunctionDef) -> None:
                self.leaves.add(original_node.name.value + "2")

        # Parse a module and verify we visited correctly.
        module = fixture(
            """
            a = "foo"
            b = "bar"

            def foo() -> None:
                return "baz"

            def bar() -> None:
                return "foobar"

            def baz() -> None:
                return "foobar"
            """
        )
        visitor = TestVisitor()
        module.visit(visitor)

        # Both functions registered for the same matcher must have been called.
        # Sets are used so the assertion does not depend on call order.
        self.assertEqual(visitor.visits, {"foo1", "foo2"})
        self.assertEqual(visitor.leaves, {"bar1", "bar2"})

    def test_gated_visit_transform(self) -> None:
        # Set up a visitor combining matcher visit/leave decorators with
        # call_if_inside / call_if_not_inside gates.
        class TestVisitor(MatcherDecoratableTransformer):
            def __init__(self) -> None:
                super().__init__()
                self.visits: Set[str] = set()
                self.leaves: Set[str] = set()

            @call_if_inside(m.FunctionDef(m.Name("foo")))
            @visit(m.SimpleString())
            def visit_string1(self, node: cst.SimpleString) -> None:
                self.visits.add(literal_eval(node.value) + "1")

            @call_if_not_inside(m.FunctionDef(m.Name("bar")))
            @visit(m.SimpleString())
            def visit_string2(self, node: cst.SimpleString) -> None:
                self.visits.add(literal_eval(node.value) + "2")

            @call_if_inside(m.FunctionDef(m.Name("baz")))
            @leave(m.SimpleString())
            def leave_string1(
                self, original_node: cst.SimpleString, updated_node: cst.SimpleString
            ) -> cst.SimpleString:
                self.leaves.add(literal_eval(updated_node.value) + "1")
                return updated_node

            @call_if_not_inside(m.FunctionDef(m.Name("foo")))
            @leave(m.SimpleString())
            def leave_string2(
                self, original_node: cst.SimpleString, updated_node: cst.SimpleString
            ) -> cst.SimpleString:
                self.leaves.add(literal_eval(updated_node.value) + "2")
                return updated_node

        # Parse a module and verify we visited correctly.
        module = fixture(
            """
            a = "foo"
            b = "bar"

            def foo() -> None:
                return "baz"

            def bar() -> None:
                return "foobar"

            def baz() -> None:
                return "foobarbaz"
            """
        )
        visitor = TestVisitor()
        module.visit(visitor)

        # Each recorded entry is "<string contents><function suffix>", so the
        # sets show exactly which gated function fired for which string.
        self.assertEqual(visitor.visits, {"baz1", "foo2", "bar2", "baz2", "foobarbaz2"})
        self.assertEqual(
            visitor.leaves, {"foobarbaz1", "foo2", "bar2", "foobar2", "foobarbaz2"}
        )

    def test_gated_visit_collect(self) -> None:
        # Set up a visitor combining matcher visit/leave decorators with
        # call_if_inside / call_if_not_inside gates.
        class TestVisitor(MatcherDecoratableVisitor):
            def __init__(self) -> None:
                super().__init__()
                self.visits: Set[str] = set()
                self.leaves: Set[str] = set()

            @call_if_inside(m.FunctionDef(m.Name("foo")))
            @visit(m.SimpleString())
            def visit_string1(self, node: cst.SimpleString) -> None:
                self.visits.add(literal_eval(node.value) + "1")

            @call_if_not_inside(m.FunctionDef(m.Name("bar")))
            @visit(m.SimpleString())
            def visit_string2(self, node: cst.SimpleString) -> None:
                self.visits.add(literal_eval(node.value) + "2")

            @call_if_inside(m.FunctionDef(m.Name("baz")))
            @leave(m.SimpleString())
            def leave_string1(self, original_node: cst.SimpleString) -> None:
                self.leaves.add(literal_eval(original_node.value) + "1")

            @call_if_not_inside(m.FunctionDef(m.Name("foo")))
            @leave(m.SimpleString())
            def leave_string2(self, original_node: cst.SimpleString) -> None:
                self.leaves.add(literal_eval(original_node.value) + "2")

        # Parse a module and verify we visited correctly.
        module = fixture(
            """
            a = "foo"
            b = "bar"

            def foo() -> None:
                return "baz"

            def bar() -> None:
                return "foobar"

            def baz() -> None:
                return "foobarbaz"
            """
        )
        visitor = TestVisitor()
        module.visit(visitor)

        # Each recorded entry is "<string contents><function suffix>", so the
        # sets show exactly which gated function fired for which string.
        self.assertEqual(visitor.visits, {"baz1", "foo2", "bar2", "baz2", "foobarbaz2"})
        self.assertEqual(
            visitor.leaves, {"foobarbaz1", "foo2", "bar2", "foobar2", "foobarbaz2"}
        )

    def test_transform_order(self) -> None:
        # Set up a transformer with several leave functions that all rewrite the
        # same SimpleString nodes inside bar().
        class TestVisitor(MatcherDecoratableTransformer):
            @call_if_inside(m.FunctionDef(m.Name("bar")))
            @leave(m.SimpleString())
            def leave_string1(
                self, original_node: cst.SimpleString, updated_node: cst.SimpleString
            ) -> cst.SimpleString:
                return updated_node.with_changes(
                    value=f'"prefix{literal_eval(updated_node.value)}"'
                )

            @call_if_inside(m.FunctionDef(m.Name("bar")))
            @leave(m.SimpleString())
            def leave_string2(
                self, original_node: cst.SimpleString, updated_node: cst.SimpleString
            ) -> cst.SimpleString:
                return updated_node.with_changes(
                    value=f'"{literal_eval(updated_node.value)}suffix"'
                )

            @call_if_inside(m.FunctionDef(m.Name("bar")))
            def leave_SimpleString(
                self, original_node: cst.SimpleString, updated_node: cst.SimpleString
            ) -> cst.SimpleString:
                return updated_node.with_changes(
                    value=f'"{"".join(reversed(literal_eval(updated_node.value)))}"'
                )

        # Parse a module and verify we transformed correctly.
        module = fixture(
            """
            a = "foo"
            b = "bar"

            def foo() -> None:
                return "baz"

            def bar() -> None:
                return "foobar"

            def baz() -> None:
                return "foobarbaz"
            """
        )
        visitor = TestVisitor()
        actual = module.visit(visitor)
        # Only "foobar" is reversed (not the added prefix/suffix), so the plain
        # leave_SimpleString must run before the decorated leave functions.
        expected = fixture(
            """
            a = "foo"
            b = "bar"

            def foo() -> None:
                return "baz"

            def bar() -> None:
                return "prefixraboofsuffix"

            def baz() -> None:
                return "foobarbaz"
            """
        )
        self.assertTrue(
            expected.deep_equals(actual),
            msg=f"Unexpected transform output:\n{actual.code}",
        )

    def test_call_if_inside_visitor_attribute(self) -> None:
        # Set up a simple visitor with call_if_inside on attribute-level
        # visit/leave functions.
        class TestVisitor(MatcherDecoratableVisitor):
            def __init__(self) -> None:
                super().__init__()
                self.visits: List[str] = []
                self.leaves: List[str] = []

            @call_if_inside(m.FunctionDef(m.Name("foo")))
            def visit_SimpleString_lpar(self, node: cst.SimpleString) -> None:
                self.visits.append(node.value)

            @call_if_inside(m.FunctionDef())
            def leave_SimpleString_lpar(self, node: cst.SimpleString) -> None:
                self.leaves.append(node.value)

        # Parse a module and verify we visited correctly.
        module = fixture(
            """
            a = "foo"
            b = "bar"

            def foo() -> None:
                return "baz"

            def bar() -> None:
                return "foobar"
            """
        )
        visitor = TestVisitor()
        module.visit(visitor)

        # We should have only visited a select number of nodes.
        self.assertEqual(visitor.visits, ['"baz"'])
        self.assertEqual(visitor.leaves, ['"baz"', '"foobar"'])

    def test_call_if_inside_transform_attribute(self) -> None:
        # Set up a simple transformer with call_if_inside on attribute-level
        # visit/leave functions.
        class TestVisitor(MatcherDecoratableTransformer):
            def __init__(self) -> None:
                super().__init__()
                self.visits: List[str] = []
                self.leaves: List[str] = []

            @call_if_inside(m.FunctionDef(m.Name("foo")))
            def visit_SimpleString_lpar(self, node: cst.SimpleString) -> None:
                self.visits.append(node.value)

            @call_if_inside(m.FunctionDef())
            def leave_SimpleString_lpar(self, node: cst.SimpleString) -> None:
                self.leaves.append(node.value)

        # Parse a module and verify we visited correctly.
        module = fixture(
            """
            a = "foo"
            b = "bar"

            def foo() -> None:
                return "baz"

            def bar() -> None:
                return "foobar"
            """
        )
        visitor = TestVisitor()
        module.visit(visitor)

        # We should have only visited a select number of nodes.
        self.assertEqual(visitor.visits, ['"baz"'])
        self.assertEqual(visitor.leaves, ['"baz"', '"foobar"'])

    def test_call_if_not_inside_visitor_attribute(self) -> None:
        # Set up a simple visitor with call_if_not_inside on attribute-level
        # visit/leave functions.
        class TestVisitor(MatcherDecoratableVisitor):
            def __init__(self) -> None:
                super().__init__()
                self.visits: List[str] = []
                self.leaves: List[str] = []

            @call_if_not_inside(m.FunctionDef(m.Name("foo")))
            def visit_SimpleString_lpar(self, node: cst.SimpleString) -> None:
                self.visits.append(node.value)

            @call_if_not_inside(m.FunctionDef())
            def leave_SimpleString_lpar(self, node: cst.SimpleString) -> None:
                self.leaves.append(node.value)

        # Parse a module and verify we visited correctly.
        module = fixture(
            """
            a = "foo"
            b = "bar"

            def foo() -> None:
                return "baz"

            def bar() -> None:
                return "foobar"
            """
        )
        visitor = TestVisitor()
        module.visit(visitor)

        # We should have only visited a select number of nodes.
        self.assertEqual(visitor.visits, ['"foo"', '"bar"', '"foobar"'])
        self.assertEqual(visitor.leaves, ['"foo"', '"bar"'])

    def test_call_if_not_inside_transform_attribute(self) -> None:
        # Set up a simple transformer with call_if_not_inside on attribute-level
        # visit/leave functions.
        class TestVisitor(MatcherDecoratableTransformer):
            def __init__(self) -> None:
                super().__init__()
                self.visits: List[str] = []
                self.leaves: List[str] = []

            @call_if_not_inside(m.FunctionDef(m.Name("foo")))
            def visit_SimpleString_lpar(self, node: cst.SimpleString) -> None:
                self.visits.append(node.value)

            @call_if_not_inside(m.FunctionDef())
            def leave_SimpleString_lpar(self, node: cst.SimpleString) -> None:
                self.leaves.append(node.value)

        # Parse a module and verify we visited correctly.
        module = fixture(
            """
            a = "foo"
            b = "bar"

            def foo() -> None:
                return "baz"

            def bar() -> None:
                return "foobar"
            """
        )
        visitor = TestVisitor()
        module.visit(visitor)

        # We should have only visited a select number of nodes.
        self.assertEqual(visitor.visits, ['"foo"', '"bar"', '"foobar"'])
        self.assertEqual(visitor.leaves, ['"foo"', '"bar"'])

    def test_init_with_unhashable_types(self) -> None:
        # Set up a simple visitor whose call_if_inside matcher embeds a list,
        # which is an unhashable value.
        class TestVisitor(MatcherDecoratableTransformer):
            def __init__(self) -> None:
                super().__init__()
                self.visits: List[str] = []

            @call_if_inside(
                m.FunctionDef(m.Name("foo"), params=m.Parameters([m.ZeroOrMore()]))
            )
            def visit_SimpleString(self, node: cst.SimpleString) -> None:
                self.visits.append(node.value)

        # Parse a module and verify we visited correctly.
        module = fixture(
            """
            a = "foo"
            b = "bar"

            def foo() -> None:
                return "baz"

            def bar() -> None:
                return "foobar"
            """
        )
        visitor = TestVisitor()
        module.visit(visitor)

        # We should have only visited a select number of nodes.
        self.assertEqual(visitor.visits, ['"baz"'])


# This is meant to simulate `cst.ImportFrom | cst.RemovalSentinel` in py3.10
FakeUnionClass: Mock = Mock()
setattr(FakeUnionClass, "__name__", "Union")
setattr(FakeUnionClass, "__module__", "types")
FakeUnion: Mock = Mock()
FakeUnion.__class__ = FakeUnionClass
FakeUnion.__args__ = [cst.ImportFrom, cst.RemovalSentinel]


class MatchersUnionDecoratorsTest(UnitTest):
    """Tests that decorated functions accept PEP 604-style union annotations."""

    def test_init_with_new_union_annotation(self) -> None:
        class TransformerWithUnionReturnAnnotation(m.MatcherDecoratableTransformer):
            @m.leave(m.ImportFrom(module=m.Name(value="typing")))
            def test(
                self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom
            ) -> FakeUnion:
                pass

        # assert that init (specifically _check_types on return annotation) passes
        TransformerWithUnionReturnAnnotation()