1""" 2Unit tests for refactor.py. 3""" 4 5from __future__ import with_statement 6 7import sys 8import os 9import codecs 10import operator 11import re 12import StringIO 13import tempfile 14import shutil 15import unittest 16import warnings 17 18from lib2to3 import refactor, pygram, fixer_base 19from lib2to3.pgen2 import token 20 21from . import support 22 23 24TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "data") 25FIXER_DIR = os.path.join(TEST_DATA_DIR, "fixers") 26 27sys.path.append(FIXER_DIR) 28try: 29 _DEFAULT_FIXERS = refactor.get_fixers_from_package("myfixes") 30finally: 31 sys.path.pop() 32 33_2TO3_FIXERS = refactor.get_fixers_from_package("lib2to3.fixes") 34 35class TestRefactoringTool(unittest.TestCase): 36 37 def setUp(self): 38 sys.path.append(FIXER_DIR) 39 40 def tearDown(self): 41 sys.path.pop() 42 43 def check_instances(self, instances, classes): 44 for inst, cls in zip(instances, classes): 45 if not isinstance(inst, cls): 46 self.fail("%s are not instances of %s" % instances, classes) 47 48 def rt(self, options=None, fixers=_DEFAULT_FIXERS, explicit=None): 49 return refactor.RefactoringTool(fixers, options, explicit) 50 51 def test_print_function_option(self): 52 rt = self.rt({"print_function" : True}) 53 self.assertIs(rt.grammar, pygram.python_grammar_no_print_statement) 54 self.assertIs(rt.driver.grammar, 55 pygram.python_grammar_no_print_statement) 56 57 def test_write_unchanged_files_option(self): 58 rt = self.rt() 59 self.assertFalse(rt.write_unchanged_files) 60 rt = self.rt({"write_unchanged_files" : True}) 61 self.assertTrue(rt.write_unchanged_files) 62 63 def test_fixer_loading_helpers(self): 64 contents = ["explicit", "first", "last", "parrot", "preorder"] 65 non_prefixed = refactor.get_all_fix_names("myfixes") 66 prefixed = refactor.get_all_fix_names("myfixes", False) 67 full_names = refactor.get_fixers_from_package("myfixes") 68 self.assertEqual(prefixed, ["fix_" + name for name in contents]) 69 self.assertEqual(non_prefixed, contents) 70 self.assertEqual(full_names, 71 ["myfixes.fix_" + name for name in contents]) 72 73 def test_detect_future_features(self): 74 run = refactor._detect_future_features 75 fs = frozenset 76 empty = fs() 77 self.assertEqual(run(""), empty) 78 self.assertEqual(run("from __future__ import print_function"), 79 fs(("print_function",))) 80 self.assertEqual(run("from __future__ import generators"), 81 fs(("generators",))) 82 self.assertEqual(run("from __future__ import generators, feature"), 83 fs(("generators", "feature"))) 84 inp = "from __future__ import generators, print_function" 85 self.assertEqual(run(inp), fs(("generators", "print_function"))) 86 inp ="from __future__ import print_function, generators" 87 self.assertEqual(run(inp), fs(("print_function", "generators"))) 88 inp = "from __future__ import (print_function,)" 89 self.assertEqual(run(inp), fs(("print_function",))) 90 inp = "from __future__ import (generators, print_function)" 91 self.assertEqual(run(inp), fs(("generators", "print_function"))) 92 inp = "from __future__ import (generators, nested_scopes)" 93 self.assertEqual(run(inp), fs(("generators", "nested_scopes"))) 94 inp = """from __future__ import generators 95from __future__ import print_function""" 96 self.assertEqual(run(inp), fs(("generators", "print_function"))) 97 invalid = ("from", 98 "from 4", 99 "from x", 100 "from x 5", 101 "from x im", 102 "from x import", 103 "from x import 4", 104 ) 105 for inp in invalid: 106 self.assertEqual(run(inp), empty) 107 inp = "'docstring'\nfrom __future__ import print_function" 108 self.assertEqual(run(inp), fs(("print_function",))) 109 inp = "'docstring'\n'somng'\nfrom __future__ import print_function" 110 self.assertEqual(run(inp), empty) 111 inp = "# comment\nfrom __future__ import print_function" 112 self.assertEqual(run(inp), fs(("print_function",))) 113 inp = "# comment\n'doc'\nfrom __future__ import print_function" 114 self.assertEqual(run(inp), fs(("print_function",))) 115 inp = "class x: pass\nfrom __future__ import print_function" 116 self.assertEqual(run(inp), empty) 117 118 def test_get_headnode_dict(self): 119 class NoneFix(fixer_base.BaseFix): 120 pass 121 122 class FileInputFix(fixer_base.BaseFix): 123 PATTERN = "file_input< any * >" 124 125 class SimpleFix(fixer_base.BaseFix): 126 PATTERN = "'name'" 127 128 no_head = NoneFix({}, []) 129 with_head = FileInputFix({}, []) 130 simple = SimpleFix({}, []) 131 d = refactor._get_headnode_dict([no_head, with_head, simple]) 132 top_fixes = d.pop(pygram.python_symbols.file_input) 133 self.assertEqual(top_fixes, [with_head, no_head]) 134 name_fixes = d.pop(token.NAME) 135 self.assertEqual(name_fixes, [simple, no_head]) 136 for fixes in d.itervalues(): 137 self.assertEqual(fixes, [no_head]) 138 139 def test_fixer_loading(self): 140 from myfixes.fix_first import FixFirst 141 from myfixes.fix_last import FixLast 142 from myfixes.fix_parrot import FixParrot 143 from myfixes.fix_preorder import FixPreorder 144 145 rt = self.rt() 146 pre, post = rt.get_fixers() 147 148 self.check_instances(pre, [FixPreorder]) 149 self.check_instances(post, [FixFirst, FixParrot, FixLast]) 150 151 def test_naughty_fixers(self): 152 self.assertRaises(ImportError, self.rt, fixers=["not_here"]) 153 self.assertRaises(refactor.FixerError, self.rt, fixers=["no_fixer_cls"]) 154 self.assertRaises(refactor.FixerError, self.rt, fixers=["bad_order"]) 155 156 def test_refactor_string(self): 157 rt = self.rt() 158 input = "def parrot(): pass\n\n" 159 tree = rt.refactor_string(input, "<test>") 160 self.assertNotEqual(str(tree), input) 161 162 input = "def f(): pass\n\n" 163 tree = rt.refactor_string(input, "<test>") 164 self.assertEqual(str(tree), input) 165 166 def test_refactor_stdin(self): 167 168 class MyRT(refactor.RefactoringTool): 169 170 def print_output(self, old_text, new_text, filename, equal): 171 results.extend([old_text, new_text, filename, equal]) 172 173 results = [] 174 rt = MyRT(_DEFAULT_FIXERS) 175 save = sys.stdin 176 sys.stdin = StringIO.StringIO("def parrot(): pass\n\n") 177 try: 178 rt.refactor_stdin() 179 finally: 180 sys.stdin = save 181 expected = ["def parrot(): pass\n\n", 182 "def cheese(): pass\n\n", 183 "<stdin>", False] 184 self.assertEqual(results, expected) 185 186 def check_file_refactoring(self, test_file, fixers=_2TO3_FIXERS, 187 options=None, mock_log_debug=None, 188 actually_write=True): 189 tmpdir = tempfile.mkdtemp(prefix="2to3-test_refactor") 190 self.addCleanup(shutil.rmtree, tmpdir) 191 # make a copy of the tested file that we can write to 192 shutil.copy(test_file, tmpdir) 193 test_file = os.path.join(tmpdir, os.path.basename(test_file)) 194 os.chmod(test_file, 0o644) 195 196 def read_file(): 197 with open(test_file, "rb") as fp: 198 return fp.read() 199 200 old_contents = read_file() 201 rt = self.rt(fixers=fixers, options=options) 202 if mock_log_debug: 203 rt.log_debug = mock_log_debug 204 205 rt.refactor_file(test_file) 206 self.assertEqual(old_contents, read_file()) 207 208 if not actually_write: 209 return 210 rt.refactor_file(test_file, True) 211 new_contents = read_file() 212 self.assertNotEqual(old_contents, new_contents) 213 return new_contents 214 215 def test_refactor_file(self): 216 test_file = os.path.join(FIXER_DIR, "parrot_example.py") 217 self.check_file_refactoring(test_file, _DEFAULT_FIXERS) 218 219 def test_refactor_file_write_unchanged_file(self): 220 test_file = os.path.join(FIXER_DIR, "parrot_example.py") 221 debug_messages = [] 222 def recording_log_debug(msg, *args): 223 debug_messages.append(msg % args) 224 self.check_file_refactoring(test_file, fixers=(), 225 options={"write_unchanged_files": True}, 226 mock_log_debug=recording_log_debug, 227 actually_write=False) 228 # Testing that it logged this message when write=False was passed is 229 # sufficient to see that it did not bail early after "No changes". 230 message_regex = r"Not writing changes to .*%s" % \ 231 re.escape(os.sep + os.path.basename(test_file)) 232 for message in debug_messages: 233 if "Not writing changes" in message: 234 self.assertRegexpMatches(message, message_regex) 235 break 236 else: 237 self.fail("%r not matched in %r" % (message_regex, debug_messages)) 238 239 def test_refactor_dir(self): 240 def check(structure, expected): 241 def mock_refactor_file(self, f, *args): 242 got.append(f) 243 save_func = refactor.RefactoringTool.refactor_file 244 refactor.RefactoringTool.refactor_file = mock_refactor_file 245 rt = self.rt() 246 got = [] 247 dir = tempfile.mkdtemp(prefix="2to3-test_refactor") 248 try: 249 os.mkdir(os.path.join(dir, "a_dir")) 250 for fn in structure: 251 open(os.path.join(dir, fn), "wb").close() 252 rt.refactor_dir(dir) 253 finally: 254 refactor.RefactoringTool.refactor_file = save_func 255 shutil.rmtree(dir) 256 self.assertEqual(got, 257 [os.path.join(dir, path) for path in expected]) 258 check([], []) 259 tree = ["nothing", 260 "hi.py", 261 ".dumb", 262 ".after.py", 263 "notpy.npy", 264 "sappy"] 265 expected = ["hi.py"] 266 check(tree, expected) 267 tree = ["hi.py", 268 os.path.join("a_dir", "stuff.py")] 269 check(tree, tree) 270 271 def test_file_encoding(self): 272 fn = os.path.join(TEST_DATA_DIR, "different_encoding.py") 273 self.check_file_refactoring(fn) 274 275 def test_false_file_encoding(self): 276 fn = os.path.join(TEST_DATA_DIR, "false_encoding.py") 277 data = self.check_file_refactoring(fn) 278 279 def test_bom(self): 280 fn = os.path.join(TEST_DATA_DIR, "bom.py") 281 data = self.check_file_refactoring(fn) 282 self.assertTrue(data.startswith(codecs.BOM_UTF8)) 283 284 def test_crlf_newlines(self): 285 old_sep = os.linesep 286 os.linesep = "\r\n" 287 try: 288 fn = os.path.join(TEST_DATA_DIR, "crlf.py") 289 fixes = refactor.get_fixers_from_package("lib2to3.fixes") 290 self.check_file_refactoring(fn, fixes) 291 finally: 292 os.linesep = old_sep 293 294 def test_refactor_docstring(self): 295 rt = self.rt() 296 297 doc = """ 298>>> example() 29942 300""" 301 out = rt.refactor_docstring(doc, "<test>") 302 self.assertEqual(out, doc) 303 304 doc = """ 305>>> def parrot(): 306... return 43 307""" 308 out = rt.refactor_docstring(doc, "<test>") 309 self.assertNotEqual(out, doc) 310 311 def test_explicit(self): 312 from myfixes.fix_explicit import FixExplicit 313 314 rt = self.rt(fixers=["myfixes.fix_explicit"]) 315 self.assertEqual(len(rt.post_order), 0) 316 317 rt = self.rt(explicit=["myfixes.fix_explicit"]) 318 for fix in rt.post_order: 319 if isinstance(fix, FixExplicit): 320 break 321 else: 322 self.fail("explicit fixer not loaded") 323