# Copyright 2006 Google, Inc. All Rights Reserved.
# Licensed to PSF under a Contributor Agreement.

"""Refactoring framework.

Used as a main program, this can refactor any number of files and/or
recursively descend down directories.  Imported as a module, this
provides infrastructure to write your own refactoring tool.
"""

__author__ = "Guido van Rossum <guido@python.org>"


# Python imports
import io
import os
import pkgutil
import sys
import logging
import operator
import collections
from itertools import chain

# Local imports
from .pgen2 import driver, tokenize, token
from .fixer_util import find_root
from . import pytree, pygram
from . import btm_matcher as bm


def get_all_fix_names(fixer_pkg, remove_prefix=True):
    """Return a sorted list of all available fix names in the given package."""
    pkg = __import__(fixer_pkg, [], [], ["*"])
    fix_names = []
    for finder, name, ispkg in pkgutil.iter_modules(pkg.__path__):
        if name.startswith("fix_"):
            if remove_prefix:
                name = name[4:]
            fix_names.append(name)
    return fix_names


class _EveryNode(Exception):
    pass


def _get_head_types(pat):
    """ Accepts a pytree Pattern Node and returns a set
    of the pattern types which will match first. """

    if isinstance(pat, (pytree.NodePattern, pytree.LeafPattern)):
        # NodePatterns must either have no type and no content
        # or a type and content -- so they don't get any further.
        # Always return leaf types.
        if pat.type is None:
            raise _EveryNode
        return {pat.type}

    if isinstance(pat, pytree.NegatedPattern):
        if pat.content:
            return _get_head_types(pat.content)
        raise _EveryNode  # Negated Patterns don't have a type

    if isinstance(pat, pytree.WildcardPattern):
        # Recurse on each node in content
        r = set()
        for p in pat.content:
            for x in p:
                r.update(_get_head_types(x))
        return r

    raise Exception("Oh no! I don't understand pattern %s" % (pat,))


def _get_headnode_dict(fixer_list):
    """ Accepts a list of fixers and returns a dictionary
    of head node type --> fixer list.  """
    head_nodes = collections.defaultdict(list)
    every = []
    for fixer in fixer_list:
        if fixer.pattern:
            try:
                heads = _get_head_types(fixer.pattern)
            except _EveryNode:
                every.append(fixer)
            else:
                for node_type in heads:
                    head_nodes[node_type].append(fixer)
        else:
            if fixer._accept_type is not None:
                head_nodes[fixer._accept_type].append(fixer)
            else:
                every.append(fixer)
    for node_type in chain(pygram.python_grammar.symbol2number.values(),
                           pygram.python_grammar.tokens):
        head_nodes[node_type].extend(every)
    return dict(head_nodes)


def get_fixers_from_package(pkg_name):
    """
    Return the fully qualified names for fixers in the package pkg_name.
    """
    return [pkg_name + "." + fix_name
            for fix_name in get_all_fix_names(pkg_name, False)]
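
# Usage sketch (illustrative, not part of the framework): the helpers above
# can be combined to discover fixers at runtime, assuming the fixer package
# that ships with lib2to3 is importable:
#
#   from lib2to3.refactor import get_fixers_from_package
#   names = get_fixers_from_package("lib2to3.fixes")
#   # names now holds entries such as "lib2to3.fixes.fix_print"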

def _identity(obj):
    return obj


def _detect_future_features(source):
    have_docstring = False
    gen = tokenize.generate_tokens(io.StringIO(source).readline)
    def advance():
        tok = next(gen)
        return tok[0], tok[1]
    ignore = frozenset({token.NEWLINE, tokenize.NL, token.COMMENT})
    features = set()
    try:
        while True:
            tp, value = advance()
            if tp in ignore:
                continue
            elif tp == token.STRING:
                if have_docstring:
                    break
                have_docstring = True
            elif tp == token.NAME and value == "from":
                tp, value = advance()
                if tp != token.NAME or value != "__future__":
                    break
                tp, value = advance()
                if tp != token.NAME or value != "import":
                    break
                tp, value = advance()
                if tp == token.OP and value == "(":
                    tp, value = advance()
                while tp == token.NAME:
                    features.add(value)
                    tp, value = advance()
                    if tp != token.OP or value != ",":
                        break
                    tp, value = advance()
            else:
                break
    except StopIteration:
        pass
    return frozenset(features)


class FixerError(Exception):
    """A fixer could not be loaded."""


class RefactoringTool(object):

    _default_options = {"print_function": False,
                        "exec_function": False,
                        "write_unchanged_files": False}

    CLASS_PREFIX = "Fix"  # The prefix for fixer classes
    FILE_PREFIX = "fix_"  # The prefix for modules with a fixer within

    def __init__(self, fixer_names, options=None, explicit=None):
        """Initializer.

        Args:
            fixer_names: a list of fixers to import
            options: a dict with configuration.
            explicit: a list of fixers to run even if they are marked as
                explicit, or True to run all explicit fixers.
        """
        self.fixers = fixer_names
        self.explicit = explicit or []
        self.options = self._default_options.copy()
        if options is not None:
            self.options.update(options)
        self.grammar = pygram.python_grammar.copy()

        if self.options['print_function']:
            del self.grammar.keywords["print"]
        elif self.options['exec_function']:
            del self.grammar.keywords["exec"]

        # When this is True, the refactor*() methods will call write_file()
        # even for files that were not changed during refactoring, provided
        # the refactor method's write parameter was True.
        self.write_unchanged_files = self.options.get("write_unchanged_files")
        self.errors = []
        self.logger = logging.getLogger("RefactoringTool")
        self.fixer_log = []
        self.wrote = False
        self.driver = driver.Driver(self.grammar,
                                    convert=pytree.convert,
                                    logger=self.logger)
        self.pre_order, self.post_order = self.get_fixers()

        self.files = []  # List of files that were or should be modified

        self.BM = bm.BottomMatcher()
        self.bmi_pre_order = []  # Bottom Matcher incompatible fixers
        self.bmi_post_order = []

        for fixer in chain(self.post_order, self.pre_order):
            if fixer.BM_compatible:
                self.BM.add_fixer(fixer)
            # Fixers the bottom-up matcher cannot handle fall back to the
            # traditional per-node traversal below.
            elif fixer in self.pre_order:
                self.bmi_pre_order.append(fixer)
            elif fixer in self.post_order:
                self.bmi_post_order.append(fixer)

        self.bmi_pre_order_heads = _get_headnode_dict(self.bmi_pre_order)
        self.bmi_post_order_heads = _get_headnode_dict(self.bmi_post_order)
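
    # Construction sketch (illustrative): a tool is normally built from fully
    # qualified fixer names, e.g.
    #
    #   from lib2to3.refactor import RefactoringTool, get_fixers_from_package
    #   rt = RefactoringTool(get_fixers_from_package("lib2to3.fixes"))
    #
    # which imports every fixer module and partitions the fixer instances
    # into the pre-order and post-order lists set up above.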

    def get_fixers(self):
        """Inspects the options to load the requested patterns and handlers.

        Returns:
            (pre_order, post_order), where pre_order is the list of fixers
            that want a pre-order AST traversal, and post_order is the list
            that wants post-order traversal.
        """
        pre_order_fixers = []
        post_order_fixers = []
        for fix_mod_path in self.fixers:
            mod = __import__(fix_mod_path, {}, {}, ["*"])
            fix_name = fix_mod_path.rsplit(".", 1)[-1]
            if fix_name.startswith(self.FILE_PREFIX):
                fix_name = fix_name[len(self.FILE_PREFIX):]
            parts = fix_name.split("_")
            class_name = self.CLASS_PREFIX + "".join([p.title() for p in parts])
            try:
                fix_class = getattr(mod, class_name)
            except AttributeError:
                raise FixerError("Can't find %s.%s" % (fix_name, class_name)) from None
            fixer = fix_class(self.options, self.fixer_log)
            if fixer.explicit and self.explicit is not True and \
                    fix_mod_path not in self.explicit:
                self.log_message("Skipping optional fixer: %s", fix_name)
                continue

            self.log_debug("Adding transformation: %s", fix_name)
            if fixer.order == "pre":
                pre_order_fixers.append(fixer)
            elif fixer.order == "post":
                post_order_fixers.append(fixer)
            else:
                raise FixerError("Illegal fixer order: %r" % fixer.order)

        key_func = operator.attrgetter("run_order")
        pre_order_fixers.sort(key=key_func)
        post_order_fixers.sort(key=key_func)
        return (pre_order_fixers, post_order_fixers)

    def log_error(self, msg, *args, **kwds):
        """Called when an error occurs."""
        raise

    def log_message(self, msg, *args):
        """Hook to log a message."""
        if args:
            msg = msg % args
        self.logger.info(msg)

    def log_debug(self, msg, *args):
        """Hook to log a debug message."""
        if args:
            msg = msg % args
        self.logger.debug(msg)

    def print_output(self, old_text, new_text, filename, equal):
        """Called with the old version, new version, and filename of a
        refactored file."""
        pass

    def refactor(self, items, write=False, doctests_only=False):
        """Refactor a list of files and directories."""

        for dir_or_file in items:
            if os.path.isdir(dir_or_file):
                self.refactor_dir(dir_or_file, write, doctests_only)
            else:
                self.refactor_file(dir_or_file, write, doctests_only)

    def refactor_dir(self, dir_name, write=False, doctests_only=False):
        """Descends into a directory and refactors every Python file found.

        Python files are assumed to have a .py extension.

        Files and subdirectories starting with '.' are skipped.
        """
        py_ext = os.extsep + "py"
        for dirpath, dirnames, filenames in os.walk(dir_name):
            self.log_debug("Descending into %s", dirpath)
            dirnames.sort()
            filenames.sort()
            for name in filenames:
                if (not name.startswith(".") and
                        os.path.splitext(name)[1] == py_ext):
                    fullname = os.path.join(dirpath, name)
                    self.refactor_file(fullname, write, doctests_only)
            # Modify dirnames in-place to remove subdirs with leading dots
            dirnames[:] = [dn for dn in dirnames if not dn.startswith(".")]
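
    # Illustrative call (assumes rt from the sketch above and a writable
    # tree of .py files): directories are walked recursively by
    # refactor_dir(), while plain file paths go straight to refactor_file().
    #
    #   rt.refactor(["./project"], write=True)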
311 """ 312 try: 313 f = open(filename, "rb") 314 except OSError as err: 315 self.log_error("Can't open %s: %s", filename, err) 316 return None, None 317 try: 318 encoding = tokenize.detect_encoding(f.readline)[0] 319 finally: 320 f.close() 321 with io.open(filename, "r", encoding=encoding, newline='') as f: 322 return f.read(), encoding 323 324 def refactor_file(self, filename, write=False, doctests_only=False): 325 """Refactors a file.""" 326 input, encoding = self._read_python_source(filename) 327 if input is None: 328 # Reading the file failed. 329 return 330 input += "\n" # Silence certain parse errors 331 if doctests_only: 332 self.log_debug("Refactoring doctests in %s", filename) 333 output = self.refactor_docstring(input, filename) 334 if self.write_unchanged_files or output != input: 335 self.processed_file(output, filename, input, write, encoding) 336 else: 337 self.log_debug("No doctest changes in %s", filename) 338 else: 339 tree = self.refactor_string(input, filename) 340 if self.write_unchanged_files or (tree and tree.was_changed): 341 # The [:-1] is to take off the \n we added earlier 342 self.processed_file(str(tree)[:-1], filename, 343 write=write, encoding=encoding) 344 else: 345 self.log_debug("No changes in %s", filename) 346 347 def refactor_string(self, data, name): 348 """Refactor a given input string. 349 350 Args: 351 data: a string holding the code to be refactored. 352 name: a human-readable name for use in error/log messages. 353 354 Returns: 355 An AST corresponding to the refactored input stream; None if 356 there were errors during the parse. 357 """ 358 features = _detect_future_features(data) 359 if "print_function" in features: 360 self.driver.grammar = pygram.python_grammar_no_print_statement 361 try: 362 tree = self.driver.parse_string(data) 363 except Exception as err: 364 self.log_error("Can't parse %s: %s: %s", 365 name, err.__class__.__name__, err) 366 return 367 finally: 368 self.driver.grammar = self.grammar 369 tree.future_features = features 370 self.log_debug("Refactoring %s", name) 371 self.refactor_tree(tree, name) 372 return tree 373 374 def refactor_stdin(self, doctests_only=False): 375 input = sys.stdin.read() 376 if doctests_only: 377 self.log_debug("Refactoring doctests in stdin") 378 output = self.refactor_docstring(input, "<stdin>") 379 if self.write_unchanged_files or output != input: 380 self.processed_file(output, "<stdin>", input) 381 else: 382 self.log_debug("No doctest changes in stdin") 383 else: 384 tree = self.refactor_string(input, "<stdin>") 385 if self.write_unchanged_files or (tree and tree.was_changed): 386 self.processed_file(str(tree), "<stdin>", input) 387 else: 388 self.log_debug("No changes in stdin") 389 390 def refactor_tree(self, tree, name): 391 """Refactors a parse tree (modifying the tree in place). 392 393 For compatible patterns the bottom matcher module is 394 used. Otherwise the tree is traversed node-to-node for 395 matches. 396 397 Args: 398 tree: a pytree.Node instance representing the root of the tree 399 to be refactored. 400 name: a human-readable name for this tree. 401 402 Returns: 403 True if the tree was modified, False otherwise. 
404 """ 405 406 for fixer in chain(self.pre_order, self.post_order): 407 fixer.start_tree(tree, name) 408 409 #use traditional matching for the incompatible fixers 410 self.traverse_by(self.bmi_pre_order_heads, tree.pre_order()) 411 self.traverse_by(self.bmi_post_order_heads, tree.post_order()) 412 413 # obtain a set of candidate nodes 414 match_set = self.BM.run(tree.leaves()) 415 416 while any(match_set.values()): 417 for fixer in self.BM.fixers: 418 if fixer in match_set and match_set[fixer]: 419 #sort by depth; apply fixers from bottom(of the AST) to top 420 match_set[fixer].sort(key=pytree.Base.depth, reverse=True) 421 422 if fixer.keep_line_order: 423 #some fixers(eg fix_imports) must be applied 424 #with the original file's line order 425 match_set[fixer].sort(key=pytree.Base.get_lineno) 426 427 for node in list(match_set[fixer]): 428 if node in match_set[fixer]: 429 match_set[fixer].remove(node) 430 431 try: 432 find_root(node) 433 except ValueError: 434 # this node has been cut off from a 435 # previous transformation ; skip 436 continue 437 438 if node.fixers_applied and fixer in node.fixers_applied: 439 # do not apply the same fixer again 440 continue 441 442 results = fixer.match(node) 443 444 if results: 445 new = fixer.transform(node, results) 446 if new is not None: 447 node.replace(new) 448 #new.fixers_applied.append(fixer) 449 for node in new.post_order(): 450 # do not apply the fixer again to 451 # this or any subnode 452 if not node.fixers_applied: 453 node.fixers_applied = [] 454 node.fixers_applied.append(fixer) 455 456 # update the original match set for 457 # the added code 458 new_matches = self.BM.run(new.leaves()) 459 for fxr in new_matches: 460 if not fxr in match_set: 461 match_set[fxr]=[] 462 463 match_set[fxr].extend(new_matches[fxr]) 464 465 for fixer in chain(self.pre_order, self.post_order): 466 fixer.finish_tree(tree, name) 467 return tree.was_changed 468 469 def traverse_by(self, fixers, traversal): 470 """Traverse an AST, applying a set of fixers to each node. 471 472 This is a helper method for refactor_tree(). 473 474 Args: 475 fixers: a list of fixer instances. 476 traversal: a generator that yields AST nodes. 477 478 Returns: 479 None 480 """ 481 if not fixers: 482 return 483 for node in traversal: 484 for fixer in fixers[node.type]: 485 results = fixer.match(node) 486 if results: 487 new = fixer.transform(node, results) 488 if new is not None: 489 node.replace(new) 490 node = new 491 492 def processed_file(self, new_text, filename, old_text=None, write=False, 493 encoding=None): 494 """ 495 Called when a file has been refactored and there may be changes. 496 """ 497 self.files.append(filename) 498 if old_text is None: 499 old_text = self._read_python_source(filename)[0] 500 if old_text is None: 501 return 502 equal = old_text == new_text 503 self.print_output(old_text, new_text, filename, equal) 504 if equal: 505 self.log_debug("No changes to %s", filename) 506 if not self.write_unchanged_files: 507 return 508 if write: 509 self.write_file(new_text, filename, old_text, encoding) 510 else: 511 self.log_debug("Not writing changes to %s", filename) 512 513 def write_file(self, new_text, filename, old_text, encoding=None): 514 """Writes a string to a file. 515 516 It first shows a unified diff between the old text and the new text, and 517 then rewrites the file; the latter is only done if the write option is 518 set. 
519 """ 520 try: 521 fp = io.open(filename, "w", encoding=encoding, newline='') 522 except OSError as err: 523 self.log_error("Can't create %s: %s", filename, err) 524 return 525 526 with fp: 527 try: 528 fp.write(new_text) 529 except OSError as err: 530 self.log_error("Can't write %s: %s", filename, err) 531 self.log_debug("Wrote changes to %s", filename) 532 self.wrote = True 533 534 PS1 = ">>> " 535 PS2 = "... " 536 537 def refactor_docstring(self, input, filename): 538 """Refactors a docstring, looking for doctests. 539 540 This returns a modified version of the input string. It looks 541 for doctests, which start with a ">>>" prompt, and may be 542 continued with "..." prompts, as long as the "..." is indented 543 the same as the ">>>". 544 545 (Unfortunately we can't use the doctest module's parser, 546 since, like most parsers, it is not geared towards preserving 547 the original source.) 548 """ 549 result = [] 550 block = None 551 block_lineno = None 552 indent = None 553 lineno = 0 554 for line in input.splitlines(keepends=True): 555 lineno += 1 556 if line.lstrip().startswith(self.PS1): 557 if block is not None: 558 result.extend(self.refactor_doctest(block, block_lineno, 559 indent, filename)) 560 block_lineno = lineno 561 block = [line] 562 i = line.find(self.PS1) 563 indent = line[:i] 564 elif (indent is not None and 565 (line.startswith(indent + self.PS2) or 566 line == indent + self.PS2.rstrip() + "\n")): 567 block.append(line) 568 else: 569 if block is not None: 570 result.extend(self.refactor_doctest(block, block_lineno, 571 indent, filename)) 572 block = None 573 indent = None 574 result.append(line) 575 if block is not None: 576 result.extend(self.refactor_doctest(block, block_lineno, 577 indent, filename)) 578 return "".join(result) 579 580 def refactor_doctest(self, block, lineno, indent, filename): 581 """Refactors one doctest. 582 583 A doctest is given as a block of lines, the first of which starts 584 with ">>>" (possibly indented), while the remaining lines start 585 with "..." (identically indented). 586 587 """ 588 try: 589 tree = self.parse_block(block, lineno, indent) 590 except Exception as err: 591 if self.logger.isEnabledFor(logging.DEBUG): 592 for line in block: 593 self.log_debug("Source: %s", line.rstrip("\n")) 594 self.log_error("Can't parse docstring in %s line %s: %s: %s", 595 filename, lineno, err.__class__.__name__, err) 596 return block 597 if self.refactor_tree(tree, filename): 598 new = str(tree).splitlines(keepends=True) 599 # Undo the adjustment of the line numbers in wrap_toks() below. 

    def refactor_doctest(self, block, lineno, indent, filename):
        """Refactors one doctest.

        A doctest is given as a block of lines, the first of which starts
        with ">>>" (possibly indented), while the remaining lines start
        with "..." (identically indented).

        """
        try:
            tree = self.parse_block(block, lineno, indent)
        except Exception as err:
            if self.logger.isEnabledFor(logging.DEBUG):
                for line in block:
                    self.log_debug("Source: %s", line.rstrip("\n"))
            self.log_error("Can't parse docstring in %s line %s: %s: %s",
                           filename, lineno, err.__class__.__name__, err)
            return block
        if self.refactor_tree(tree, filename):
            new = str(tree).splitlines(keepends=True)
            # Undo the adjustment of the line numbers in wrap_toks() below.
            clipped, new = new[:lineno-1], new[lineno-1:]
            assert clipped == ["\n"] * (lineno-1), clipped
            if not new[-1].endswith("\n"):
                new[-1] += "\n"
            block = [indent + self.PS1 + new.pop(0)]
            if new:
                block += [indent + self.PS2 + line for line in new]
        return block

    def summarize(self):
        if self.wrote:
            were = "were"
        else:
            were = "need to be"
        if not self.files:
            self.log_message("No files %s modified.", were)
        else:
            self.log_message("Files that %s modified:", were)
            for file in self.files:
                self.log_message(file)
        if self.fixer_log:
            self.log_message("Warnings/messages while refactoring:")
            for message in self.fixer_log:
                self.log_message(message)
        if self.errors:
            if len(self.errors) == 1:
                self.log_message("There was 1 error:")
            else:
                self.log_message("There were %d errors:", len(self.errors))
            for msg, args, kwds in self.errors:
                self.log_message(msg, *args, **kwds)

    def parse_block(self, block, lineno, indent):
        """Parses a block into a tree.

        This is necessary to get correct line number / offset information
        in the parser diagnostics and embedded into the parse tree.
        """
        tree = self.driver.parse_tokens(self.wrap_toks(block, lineno, indent))
        tree.future_features = frozenset()
        return tree

    def wrap_toks(self, block, lineno, indent):
        """Wraps a tokenize stream to systematically modify start/end."""
        tokens = tokenize.generate_tokens(self.gen_lines(block, indent).__next__)
        for type, value, (line0, col0), (line1, col1), line_text in tokens:
            line0 += lineno - 1
            line1 += lineno - 1
            # Don't bother updating the columns; this is too complicated
            # since line_text would also have to be updated and it would
            # still break for tokens spanning lines.  Let the user guess
            # that the column numbers for doctests are relative to the
            # end of the prompt string (PS1 or PS2).
            yield type, value, (line0, col0), (line1, col1), line_text
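
    # Offset sketch (descriptive): for a doctest whose first line is file
    # line 5, wrap_toks() adds lineno - 1 == 4 to every token's start and
    # end rows, so parser diagnostics point into the enclosing file rather
    # than into the isolated block.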
660 """ 661 prefix1 = indent + self.PS1 662 prefix2 = indent + self.PS2 663 prefix = prefix1 664 for line in block: 665 if line.startswith(prefix): 666 yield line[len(prefix):] 667 elif line == prefix.rstrip() + "\n": 668 yield "\n" 669 else: 670 raise AssertionError("line=%r, prefix=%r" % (line, prefix)) 671 prefix = prefix2 672 while True: 673 yield "" 674 675 676class MultiprocessingUnsupported(Exception): 677 pass 678 679 680class MultiprocessRefactoringTool(RefactoringTool): 681 682 def __init__(self, *args, **kwargs): 683 super(MultiprocessRefactoringTool, self).__init__(*args, **kwargs) 684 self.queue = None 685 self.output_lock = None 686 687 def refactor(self, items, write=False, doctests_only=False, 688 num_processes=1): 689 if num_processes == 1: 690 return super(MultiprocessRefactoringTool, self).refactor( 691 items, write, doctests_only) 692 try: 693 import multiprocessing 694 except ImportError: 695 raise MultiprocessingUnsupported 696 if self.queue is not None: 697 raise RuntimeError("already doing multiple processes") 698 self.queue = multiprocessing.JoinableQueue() 699 self.output_lock = multiprocessing.Lock() 700 processes = [multiprocessing.Process(target=self._child) 701 for i in range(num_processes)] 702 try: 703 for p in processes: 704 p.start() 705 super(MultiprocessRefactoringTool, self).refactor(items, write, 706 doctests_only) 707 finally: 708 self.queue.join() 709 for i in range(num_processes): 710 self.queue.put(None) 711 for p in processes: 712 if p.is_alive(): 713 p.join() 714 self.queue = None 715 716 def _child(self): 717 task = self.queue.get() 718 while task is not None: 719 args, kwargs = task 720 try: 721 super(MultiprocessRefactoringTool, self).refactor_file( 722 *args, **kwargs) 723 finally: 724 self.queue.task_done() 725 task = self.queue.get() 726 727 def refactor_file(self, *args, **kwargs): 728 if self.queue is not None: 729 self.queue.put((args, kwargs)) 730 else: 731 return super(MultiprocessRefactoringTool, self).refactor_file( 732 *args, **kwargs) 733