# Copyright 2006 Google, Inc. All Rights Reserved.
# Licensed to PSF under a Contributor Agreement.

"""Refactoring framework.

Used as a main program, this can refactor any number of files and/or
recursively descend into directories.  Imported as a module, this
provides infrastructure to write your own refactoring tool.
"""

__author__ = "Guido van Rossum <guido@python.org>"


# Python imports
import io
import os
import pkgutil
import sys
import logging
import operator
import collections
from itertools import chain

# Local imports
from .pgen2 import driver, tokenize, token
from .fixer_util import find_root
from . import pytree, pygram
from . import btm_matcher as bm


def get_all_fix_names(fixer_pkg, remove_prefix=True):
    """Return a sorted list of all available fix names in the given package."""
    pkg = __import__(fixer_pkg, [], [], ["*"])
    fix_names = []
    for finder, name, ispkg in pkgutil.iter_modules(pkg.__path__):
        if name.startswith("fix_"):
            if remove_prefix:
                name = name[4:]
            fix_names.append(name)
    return fix_names


class _EveryNode(Exception):
    pass


def _get_head_types(pat):
    """Accepts a pytree Pattern Node and returns a set
    of the pattern types which will match first."""

    if isinstance(pat, (pytree.NodePattern, pytree.LeafPattern)):
        # NodePatterns must either have no type and no content
        # or a type and content -- so they don't recurse any farther.
        # Leaf patterns are always returned directly.
        if pat.type is None:
            raise _EveryNode
        return {pat.type}

    if isinstance(pat, pytree.NegatedPattern):
        if pat.content:
            return _get_head_types(pat.content)
        raise _EveryNode  # Negated Patterns don't have a type

    if isinstance(pat, pytree.WildcardPattern):
        # Recurse on each node in content
        r = set()
        for p in pat.content:
            for x in p:
                r.update(_get_head_types(x))
        return r

    raise Exception("Oh no! I don't understand pattern %s" % pat)


def _get_headnode_dict(fixer_list):
    """Accepts a list of fixers and returns a dictionary
    of head node type --> fixer list."""
    head_nodes = collections.defaultdict(list)
    every = []
    for fixer in fixer_list:
        if fixer.pattern:
            try:
                heads = _get_head_types(fixer.pattern)
            except _EveryNode:
                every.append(fixer)
            else:
                for node_type in heads:
                    head_nodes[node_type].append(fixer)
        else:
            if fixer._accept_type is not None:
                head_nodes[fixer._accept_type].append(fixer)
            else:
                every.append(fixer)
    for node_type in chain(pygram.python_grammar.symbol2number.values(),
                           pygram.python_grammar.tokens):
        head_nodes[node_type].extend(every)
    return dict(head_nodes)


def get_fixers_from_package(pkg_name):
    """
    Return the fully qualified names for fixers in the package pkg_name.
    """
    return [pkg_name + "." + fix_name
            for fix_name in get_all_fix_names(pkg_name, False)]
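
# Illustrative usage of the two helpers above (a sketch, not part of the
# module's API; the names returned depend on which fixer package is
# installed):
#
#   names = get_all_fix_names("lib2to3.fixes")        # e.g. ["apply", ...]
#   paths = get_fixers_from_package("lib2to3.fixes")
#   # e.g. ["lib2to3.fixes.fix_apply", ...]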


def _identity(obj):
    return obj


def _detect_future_features(source):
    have_docstring = False
    gen = tokenize.generate_tokens(io.StringIO(source).readline)
    def advance():
        tok = next(gen)
        return tok[0], tok[1]
    ignore = frozenset({token.NEWLINE, tokenize.NL, token.COMMENT})
    features = set()
    try:
        while True:
            tp, value = advance()
            if tp in ignore:
                continue
            elif tp == token.STRING:
                if have_docstring:
                    break
                have_docstring = True
            elif tp == token.NAME and value == "from":
                tp, value = advance()
                if tp != token.NAME or value != "__future__":
                    break
                tp, value = advance()
                if tp != token.NAME or value != "import":
                    break
                tp, value = advance()
                if tp == token.OP and value == "(":
                    tp, value = advance()
                while tp == token.NAME:
                    features.add(value)
                    tp, value = advance()
                    if tp != token.OP or value != ",":
                        break
                    tp, value = advance()
            else:
                break
    except StopIteration:
        pass
    return frozenset(features)
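
# Illustrative sketch: _detect_future_features() only inspects the leading
# docstring/__future__ block, so
#
#   _detect_future_features("from __future__ import print_function, division\n")
#
# returns frozenset({"print_function", "division"}); scanning stops at the
# first statement that is neither a docstring nor a __future__ import.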


class FixerError(Exception):
    """A fixer could not be loaded."""


class RefactoringTool(object):

    _default_options = {"print_function": False,
                        "write_unchanged_files": False}

    CLASS_PREFIX = "Fix"   # The prefix for fixer classes
    FILE_PREFIX = "fix_"   # The prefix for modules with a fixer within

    def __init__(self, fixer_names, options=None, explicit=None):
        """Initializer.

        Args:
            fixer_names: a list of fixers to import
            options: a dict with configuration.
            explicit: a list of fixers to run even if they are marked as
                explicit (i.e. optional).
        """
        self.fixers = fixer_names
        self.explicit = explicit or []
        self.options = self._default_options.copy()
        if options is not None:
            self.options.update(options)
        if self.options["print_function"]:
            self.grammar = pygram.python_grammar_no_print_statement
        else:
            self.grammar = pygram.python_grammar
        # When this is True, the refactor*() methods will call write_file()
        # even for files that were not changed during refactoring -- but only
        # if the refactor method's write parameter was True.
        self.write_unchanged_files = self.options.get("write_unchanged_files")
        self.errors = []
        self.logger = logging.getLogger("RefactoringTool")
        self.fixer_log = []
        self.wrote = False
        self.driver = driver.Driver(self.grammar,
                                    convert=pytree.convert,
                                    logger=self.logger)
        self.pre_order, self.post_order = self.get_fixers()

        self.files = []  # List of files that were or should be modified

        self.BM = bm.BottomMatcher()
        self.bmi_pre_order = []  # Bottom Matcher incompatible fixers
        self.bmi_post_order = []

        for fixer in chain(self.post_order, self.pre_order):
            if fixer.BM_compatible:
                self.BM.add_fixer(fixer)
            # Remove fixers that will be handled by the bottom-up matcher.
            elif fixer in self.pre_order:
                self.bmi_pre_order.append(fixer)
            elif fixer in self.post_order:
                self.bmi_post_order.append(fixer)

        self.bmi_pre_order_heads = _get_headnode_dict(self.bmi_pre_order)
        self.bmi_post_order_heads = _get_headnode_dict(self.bmi_post_order)

    def get_fixers(self):
        """Inspects the options to load the requested patterns and handlers.

        Returns:
          (pre_order, post_order), where pre_order is the list of fixers that
          want a pre-order AST traversal, and post_order is the list that want
          post-order traversal.
        """
        pre_order_fixers = []
        post_order_fixers = []
        for fix_mod_path in self.fixers:
            mod = __import__(fix_mod_path, {}, {}, ["*"])
            fix_name = fix_mod_path.rsplit(".", 1)[-1]
            if fix_name.startswith(self.FILE_PREFIX):
                fix_name = fix_name[len(self.FILE_PREFIX):]
            parts = fix_name.split("_")
            class_name = self.CLASS_PREFIX + "".join([p.title() for p in parts])
            try:
                fix_class = getattr(mod, class_name)
            except AttributeError:
                raise FixerError("Can't find %s.%s" %
                                 (fix_name, class_name)) from None
            fixer = fix_class(self.options, self.fixer_log)
            if (fixer.explicit and self.explicit is not True and
                    fix_mod_path not in self.explicit):
                self.log_message("Skipping optional fixer: %s", fix_name)
                continue

            self.log_debug("Adding transformation: %s", fix_name)
            if fixer.order == "pre":
                pre_order_fixers.append(fixer)
            elif fixer.order == "post":
                post_order_fixers.append(fixer)
            else:
                raise FixerError("Illegal fixer order: %r" % fixer.order)

        key_func = operator.attrgetter("run_order")
        pre_order_fixers.sort(key=key_func)
        post_order_fixers.sort(key=key_func)
        return (pre_order_fixers, post_order_fixers)
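
    # Naming convention illustrated (a note, not code): a fixer module named
    # "fix_print" is expected to define a class "FixPrint" -- FILE_PREFIX is
    # stripped, the remaining underscore-separated parts are title-cased, and
    # CLASS_PREFIX is prepended.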
307 """ 308 try: 309 f = open(filename, "rb") 310 except OSError as err: 311 self.log_error("Can't open %s: %s", filename, err) 312 return None, None 313 try: 314 encoding = tokenize.detect_encoding(f.readline)[0] 315 finally: 316 f.close() 317 with io.open(filename, "r", encoding=encoding, newline='') as f: 318 return f.read(), encoding 319 320 def refactor_file(self, filename, write=False, doctests_only=False): 321 """Refactors a file.""" 322 input, encoding = self._read_python_source(filename) 323 if input is None: 324 # Reading the file failed. 325 return 326 input += "\n" # Silence certain parse errors 327 if doctests_only: 328 self.log_debug("Refactoring doctests in %s", filename) 329 output = self.refactor_docstring(input, filename) 330 if self.write_unchanged_files or output != input: 331 self.processed_file(output, filename, input, write, encoding) 332 else: 333 self.log_debug("No doctest changes in %s", filename) 334 else: 335 tree = self.refactor_string(input, filename) 336 if self.write_unchanged_files or (tree and tree.was_changed): 337 # The [:-1] is to take off the \n we added earlier 338 self.processed_file(str(tree)[:-1], filename, 339 write=write, encoding=encoding) 340 else: 341 self.log_debug("No changes in %s", filename) 342 343 def refactor_string(self, data, name): 344 """Refactor a given input string. 345 346 Args: 347 data: a string holding the code to be refactored. 348 name: a human-readable name for use in error/log messages. 349 350 Returns: 351 An AST corresponding to the refactored input stream; None if 352 there were errors during the parse. 353 """ 354 features = _detect_future_features(data) 355 if "print_function" in features: 356 self.driver.grammar = pygram.python_grammar_no_print_statement 357 try: 358 tree = self.driver.parse_string(data) 359 except Exception as err: 360 self.log_error("Can't parse %s: %s: %s", 361 name, err.__class__.__name__, err) 362 return 363 finally: 364 self.driver.grammar = self.grammar 365 tree.future_features = features 366 self.log_debug("Refactoring %s", name) 367 self.refactor_tree(tree, name) 368 return tree 369 370 def refactor_stdin(self, doctests_only=False): 371 input = sys.stdin.read() 372 if doctests_only: 373 self.log_debug("Refactoring doctests in stdin") 374 output = self.refactor_docstring(input, "<stdin>") 375 if self.write_unchanged_files or output != input: 376 self.processed_file(output, "<stdin>", input) 377 else: 378 self.log_debug("No doctest changes in stdin") 379 else: 380 tree = self.refactor_string(input, "<stdin>") 381 if self.write_unchanged_files or (tree and tree.was_changed): 382 self.processed_file(str(tree), "<stdin>", input) 383 else: 384 self.log_debug("No changes in stdin") 385 386 def refactor_tree(self, tree, name): 387 """Refactors a parse tree (modifying the tree in place). 388 389 For compatible patterns the bottom matcher module is 390 used. Otherwise the tree is traversed node-to-node for 391 matches. 392 393 Args: 394 tree: a pytree.Node instance representing the root of the tree 395 to be refactored. 396 name: a human-readable name for this tree. 397 398 Returns: 399 True if the tree was modified, False otherwise. 
400 """ 401 402 for fixer in chain(self.pre_order, self.post_order): 403 fixer.start_tree(tree, name) 404 405 #use traditional matching for the incompatible fixers 406 self.traverse_by(self.bmi_pre_order_heads, tree.pre_order()) 407 self.traverse_by(self.bmi_post_order_heads, tree.post_order()) 408 409 # obtain a set of candidate nodes 410 match_set = self.BM.run(tree.leaves()) 411 412 while any(match_set.values()): 413 for fixer in self.BM.fixers: 414 if fixer in match_set and match_set[fixer]: 415 #sort by depth; apply fixers from bottom(of the AST) to top 416 match_set[fixer].sort(key=pytree.Base.depth, reverse=True) 417 418 if fixer.keep_line_order: 419 #some fixers(eg fix_imports) must be applied 420 #with the original file's line order 421 match_set[fixer].sort(key=pytree.Base.get_lineno) 422 423 for node in list(match_set[fixer]): 424 if node in match_set[fixer]: 425 match_set[fixer].remove(node) 426 427 try: 428 find_root(node) 429 except ValueError: 430 # this node has been cut off from a 431 # previous transformation ; skip 432 continue 433 434 if node.fixers_applied and fixer in node.fixers_applied: 435 # do not apply the same fixer again 436 continue 437 438 results = fixer.match(node) 439 440 if results: 441 new = fixer.transform(node, results) 442 if new is not None: 443 node.replace(new) 444 #new.fixers_applied.append(fixer) 445 for node in new.post_order(): 446 # do not apply the fixer again to 447 # this or any subnode 448 if not node.fixers_applied: 449 node.fixers_applied = [] 450 node.fixers_applied.append(fixer) 451 452 # update the original match set for 453 # the added code 454 new_matches = self.BM.run(new.leaves()) 455 for fxr in new_matches: 456 if not fxr in match_set: 457 match_set[fxr]=[] 458 459 match_set[fxr].extend(new_matches[fxr]) 460 461 for fixer in chain(self.pre_order, self.post_order): 462 fixer.finish_tree(tree, name) 463 return tree.was_changed 464 465 def traverse_by(self, fixers, traversal): 466 """Traverse an AST, applying a set of fixers to each node. 467 468 This is a helper method for refactor_tree(). 469 470 Args: 471 fixers: a list of fixer instances. 472 traversal: a generator that yields AST nodes. 473 474 Returns: 475 None 476 """ 477 if not fixers: 478 return 479 for node in traversal: 480 for fixer in fixers[node.type]: 481 results = fixer.match(node) 482 if results: 483 new = fixer.transform(node, results) 484 if new is not None: 485 node.replace(new) 486 node = new 487 488 def processed_file(self, new_text, filename, old_text=None, write=False, 489 encoding=None): 490 """ 491 Called when a file has been refactored and there may be changes. 492 """ 493 self.files.append(filename) 494 if old_text is None: 495 old_text = self._read_python_source(filename)[0] 496 if old_text is None: 497 return 498 equal = old_text == new_text 499 self.print_output(old_text, new_text, filename, equal) 500 if equal: 501 self.log_debug("No changes to %s", filename) 502 if not self.write_unchanged_files: 503 return 504 if write: 505 self.write_file(new_text, filename, old_text, encoding) 506 else: 507 self.log_debug("Not writing changes to %s", filename) 508 509 def write_file(self, new_text, filename, old_text, encoding=None): 510 """Writes a string to a file. 511 512 It first shows a unified diff between the old text and the new text, and 513 then rewrites the file; the latter is only done if the write option is 514 set. 
515 """ 516 try: 517 fp = io.open(filename, "w", encoding=encoding, newline='') 518 except OSError as err: 519 self.log_error("Can't create %s: %s", filename, err) 520 return 521 522 with fp: 523 try: 524 fp.write(new_text) 525 except OSError as err: 526 self.log_error("Can't write %s: %s", filename, err) 527 self.log_debug("Wrote changes to %s", filename) 528 self.wrote = True 529 530 PS1 = ">>> " 531 PS2 = "... " 532 533 def refactor_docstring(self, input, filename): 534 """Refactors a docstring, looking for doctests. 535 536 This returns a modified version of the input string. It looks 537 for doctests, which start with a ">>>" prompt, and may be 538 continued with "..." prompts, as long as the "..." is indented 539 the same as the ">>>". 540 541 (Unfortunately we can't use the doctest module's parser, 542 since, like most parsers, it is not geared towards preserving 543 the original source.) 544 """ 545 result = [] 546 block = None 547 block_lineno = None 548 indent = None 549 lineno = 0 550 for line in input.splitlines(keepends=True): 551 lineno += 1 552 if line.lstrip().startswith(self.PS1): 553 if block is not None: 554 result.extend(self.refactor_doctest(block, block_lineno, 555 indent, filename)) 556 block_lineno = lineno 557 block = [line] 558 i = line.find(self.PS1) 559 indent = line[:i] 560 elif (indent is not None and 561 (line.startswith(indent + self.PS2) or 562 line == indent + self.PS2.rstrip() + "\n")): 563 block.append(line) 564 else: 565 if block is not None: 566 result.extend(self.refactor_doctest(block, block_lineno, 567 indent, filename)) 568 block = None 569 indent = None 570 result.append(line) 571 if block is not None: 572 result.extend(self.refactor_doctest(block, block_lineno, 573 indent, filename)) 574 return "".join(result) 575 576 def refactor_doctest(self, block, lineno, indent, filename): 577 """Refactors one doctest. 578 579 A doctest is given as a block of lines, the first of which starts 580 with ">>>" (possibly indented), while the remaining lines start 581 with "..." (identically indented). 582 583 """ 584 try: 585 tree = self.parse_block(block, lineno, indent) 586 except Exception as err: 587 if self.logger.isEnabledFor(logging.DEBUG): 588 for line in block: 589 self.log_debug("Source: %s", line.rstrip("\n")) 590 self.log_error("Can't parse docstring in %s line %s: %s: %s", 591 filename, lineno, err.__class__.__name__, err) 592 return block 593 if self.refactor_tree(tree, filename): 594 new = str(tree).splitlines(keepends=True) 595 # Undo the adjustment of the line numbers in wrap_toks() below. 

    PS1 = ">>> "
    PS2 = "... "

    def refactor_docstring(self, input, filename):
        """Refactors a docstring, looking for doctests.

        This returns a modified version of the input string.  It looks
        for doctests, which start with a ">>>" prompt, and may be
        continued with "..." prompts, as long as the "..." is indented
        the same as the ">>>".

        (Unfortunately we can't use the doctest module's parser,
        since, like most parsers, it is not geared towards preserving
        the original source.)
        """
        result = []
        block = None
        block_lineno = None
        indent = None
        lineno = 0
        for line in input.splitlines(keepends=True):
            lineno += 1
            if line.lstrip().startswith(self.PS1):
                if block is not None:
                    result.extend(self.refactor_doctest(block, block_lineno,
                                                        indent, filename))
                block_lineno = lineno
                block = [line]
                i = line.find(self.PS1)
                indent = line[:i]
            elif (indent is not None and
                  (line.startswith(indent + self.PS2) or
                   line == indent + self.PS2.rstrip() + "\n")):
                block.append(line)
            else:
                if block is not None:
                    result.extend(self.refactor_doctest(block, block_lineno,
                                                        indent, filename))
                block = None
                indent = None
                result.append(line)
        if block is not None:
            result.extend(self.refactor_doctest(block, block_lineno,
                                                indent, filename))
        return "".join(result)

    def refactor_doctest(self, block, lineno, indent, filename):
        """Refactors one doctest.

        A doctest is given as a block of lines, the first of which starts
        with ">>>" (possibly indented), while the remaining lines start
        with "..." (identically indented).

        """
        try:
            tree = self.parse_block(block, lineno, indent)
        except Exception as err:
            if self.logger.isEnabledFor(logging.DEBUG):
                for line in block:
                    self.log_debug("Source: %s", line.rstrip("\n"))
            self.log_error("Can't parse docstring in %s line %s: %s: %s",
                           filename, lineno, err.__class__.__name__, err)
            return block
        if self.refactor_tree(tree, filename):
            new = str(tree).splitlines(keepends=True)
            # Undo the adjustment of the line numbers in wrap_toks() below.
            clipped, new = new[:lineno-1], new[lineno-1:]
            assert clipped == ["\n"] * (lineno-1), clipped
            if not new[-1].endswith("\n"):
                new[-1] += "\n"
            block = [indent + self.PS1 + new.pop(0)]
            if new:
                block += [indent + self.PS2 + line for line in new]
        return block
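
    # Illustrative sketch: only the prompt lines of a doctest are rewritten;
    # expected-output lines pass through untouched.  A block such as
    #
    #     >>> print 'x'
    #     x
    #
    # comes back (with the standard fixers) as
    #
    #     >>> print('x')
    #     x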
656 """ 657 prefix1 = indent + self.PS1 658 prefix2 = indent + self.PS2 659 prefix = prefix1 660 for line in block: 661 if line.startswith(prefix): 662 yield line[len(prefix):] 663 elif line == prefix.rstrip() + "\n": 664 yield "\n" 665 else: 666 raise AssertionError("line=%r, prefix=%r" % (line, prefix)) 667 prefix = prefix2 668 while True: 669 yield "" 670 671 672class MultiprocessingUnsupported(Exception): 673 pass 674 675 676class MultiprocessRefactoringTool(RefactoringTool): 677 678 def __init__(self, *args, **kwargs): 679 super(MultiprocessRefactoringTool, self).__init__(*args, **kwargs) 680 self.queue = None 681 self.output_lock = None 682 683 def refactor(self, items, write=False, doctests_only=False, 684 num_processes=1): 685 if num_processes == 1: 686 return super(MultiprocessRefactoringTool, self).refactor( 687 items, write, doctests_only) 688 try: 689 import multiprocessing 690 except ImportError: 691 raise MultiprocessingUnsupported 692 if self.queue is not None: 693 raise RuntimeError("already doing multiple processes") 694 self.queue = multiprocessing.JoinableQueue() 695 self.output_lock = multiprocessing.Lock() 696 processes = [multiprocessing.Process(target=self._child) 697 for i in range(num_processes)] 698 try: 699 for p in processes: 700 p.start() 701 super(MultiprocessRefactoringTool, self).refactor(items, write, 702 doctests_only) 703 finally: 704 self.queue.join() 705 for i in range(num_processes): 706 self.queue.put(None) 707 for p in processes: 708 if p.is_alive(): 709 p.join() 710 self.queue = None 711 712 def _child(self): 713 task = self.queue.get() 714 while task is not None: 715 args, kwargs = task 716 try: 717 super(MultiprocessRefactoringTool, self).refactor_file( 718 *args, **kwargs) 719 finally: 720 self.queue.task_done() 721 task = self.queue.get() 722 723 def refactor_file(self, *args, **kwargs): 724 if self.queue is not None: 725 self.queue.put((args, kwargs)) 726 else: 727 return super(MultiprocessRefactoringTool, self).refactor_file( 728 *args, **kwargs) 729