1""" 2Unit tests for refactor.py. 3""" 4 5import sys 6import os 7import codecs 8import io 9import re 10import tempfile 11import shutil 12import unittest 13 14from lib2to3 import refactor, pygram, fixer_base 15from lib2to3.pgen2 import token 16 17 18TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "data") 19FIXER_DIR = os.path.join(TEST_DATA_DIR, "fixers") 20 21sys.path.append(FIXER_DIR) 22try: 23 _DEFAULT_FIXERS = refactor.get_fixers_from_package("myfixes") 24finally: 25 sys.path.pop() 26 27_2TO3_FIXERS = refactor.get_fixers_from_package("lib2to3.fixes") 28 29class TestRefactoringTool(unittest.TestCase): 30 31 def setUp(self): 32 sys.path.append(FIXER_DIR) 33 34 def tearDown(self): 35 sys.path.pop() 36 37 def check_instances(self, instances, classes): 38 for inst, cls in zip(instances, classes): 39 if not isinstance(inst, cls): 40 self.fail("%s are not instances of %s" % instances, classes) 41 42 def rt(self, options=None, fixers=_DEFAULT_FIXERS, explicit=None): 43 return refactor.RefactoringTool(fixers, options, explicit) 44 45 def test_print_function_option(self): 46 rt = self.rt({"print_function" : True}) 47 self.assertIs(rt.grammar, pygram.python_grammar_no_print_statement) 48 self.assertIs(rt.driver.grammar, 49 pygram.python_grammar_no_print_statement) 50 51 def test_write_unchanged_files_option(self): 52 rt = self.rt() 53 self.assertFalse(rt.write_unchanged_files) 54 rt = self.rt({"write_unchanged_files" : True}) 55 self.assertTrue(rt.write_unchanged_files) 56 57 def test_fixer_loading_helpers(self): 58 contents = ["explicit", "first", "last", "parrot", "preorder"] 59 non_prefixed = refactor.get_all_fix_names("myfixes") 60 prefixed = refactor.get_all_fix_names("myfixes", False) 61 full_names = refactor.get_fixers_from_package("myfixes") 62 self.assertEqual(prefixed, ["fix_" + name for name in contents]) 63 self.assertEqual(non_prefixed, contents) 64 self.assertEqual(full_names, 65 ["myfixes.fix_" + name for name in contents]) 66 67 def test_detect_future_features(self): 68 run = refactor._detect_future_features 69 fs = frozenset 70 empty = fs() 71 self.assertEqual(run(""), empty) 72 self.assertEqual(run("from __future__ import print_function"), 73 fs(("print_function",))) 74 self.assertEqual(run("from __future__ import generators"), 75 fs(("generators",))) 76 self.assertEqual(run("from __future__ import generators, feature"), 77 fs(("generators", "feature"))) 78 inp = "from __future__ import generators, print_function" 79 self.assertEqual(run(inp), fs(("generators", "print_function"))) 80 inp ="from __future__ import print_function, generators" 81 self.assertEqual(run(inp), fs(("print_function", "generators"))) 82 inp = "from __future__ import (print_function,)" 83 self.assertEqual(run(inp), fs(("print_function",))) 84 inp = "from __future__ import (generators, print_function)" 85 self.assertEqual(run(inp), fs(("generators", "print_function"))) 86 inp = "from __future__ import (generators, nested_scopes)" 87 self.assertEqual(run(inp), fs(("generators", "nested_scopes"))) 88 inp = """from __future__ import generators 89from __future__ import print_function""" 90 self.assertEqual(run(inp), fs(("generators", "print_function"))) 91 invalid = ("from", 92 "from 4", 93 "from x", 94 "from x 5", 95 "from x im", 96 "from x import", 97 "from x import 4", 98 ) 99 for inp in invalid: 100 self.assertEqual(run(inp), empty) 101 inp = "'docstring'\nfrom __future__ import print_function" 102 self.assertEqual(run(inp), fs(("print_function",))) 103 inp = "'docstring'\n'somng'\nfrom __future__ import print_function" 104 self.assertEqual(run(inp), empty) 105 inp = "# comment\nfrom __future__ import print_function" 106 self.assertEqual(run(inp), fs(("print_function",))) 107 inp = "# comment\n'doc'\nfrom __future__ import print_function" 108 self.assertEqual(run(inp), fs(("print_function",))) 109 inp = "class x: pass\nfrom __future__ import print_function" 110 self.assertEqual(run(inp), empty) 111 112 def test_get_headnode_dict(self): 113 class NoneFix(fixer_base.BaseFix): 114 pass 115 116 class FileInputFix(fixer_base.BaseFix): 117 PATTERN = "file_input< any * >" 118 119 class SimpleFix(fixer_base.BaseFix): 120 PATTERN = "'name'" 121 122 no_head = NoneFix({}, []) 123 with_head = FileInputFix({}, []) 124 simple = SimpleFix({}, []) 125 d = refactor._get_headnode_dict([no_head, with_head, simple]) 126 top_fixes = d.pop(pygram.python_symbols.file_input) 127 self.assertEqual(top_fixes, [with_head, no_head]) 128 name_fixes = d.pop(token.NAME) 129 self.assertEqual(name_fixes, [simple, no_head]) 130 for fixes in d.values(): 131 self.assertEqual(fixes, [no_head]) 132 133 def test_fixer_loading(self): 134 from myfixes.fix_first import FixFirst 135 from myfixes.fix_last import FixLast 136 from myfixes.fix_parrot import FixParrot 137 from myfixes.fix_preorder import FixPreorder 138 139 rt = self.rt() 140 pre, post = rt.get_fixers() 141 142 self.check_instances(pre, [FixPreorder]) 143 self.check_instances(post, [FixFirst, FixParrot, FixLast]) 144 145 def test_naughty_fixers(self): 146 self.assertRaises(ImportError, self.rt, fixers=["not_here"]) 147 self.assertRaises(refactor.FixerError, self.rt, fixers=["no_fixer_cls"]) 148 self.assertRaises(refactor.FixerError, self.rt, fixers=["bad_order"]) 149 150 def test_refactor_string(self): 151 rt = self.rt() 152 input = "def parrot(): pass\n\n" 153 tree = rt.refactor_string(input, "<test>") 154 self.assertNotEqual(str(tree), input) 155 156 input = "def f(): pass\n\n" 157 tree = rt.refactor_string(input, "<test>") 158 self.assertEqual(str(tree), input) 159 160 def test_refactor_stdin(self): 161 162 class MyRT(refactor.RefactoringTool): 163 164 def print_output(self, old_text, new_text, filename, equal): 165 results.extend([old_text, new_text, filename, equal]) 166 167 results = [] 168 rt = MyRT(_DEFAULT_FIXERS) 169 save = sys.stdin 170 sys.stdin = io.StringIO("def parrot(): pass\n\n") 171 try: 172 rt.refactor_stdin() 173 finally: 174 sys.stdin = save 175 expected = ["def parrot(): pass\n\n", 176 "def cheese(): pass\n\n", 177 "<stdin>", False] 178 self.assertEqual(results, expected) 179 180 def check_file_refactoring(self, test_file, fixers=_2TO3_FIXERS, 181 options=None, mock_log_debug=None, 182 actually_write=True): 183 test_file = self.init_test_file(test_file) 184 old_contents = self.read_file(test_file) 185 rt = self.rt(fixers=fixers, options=options) 186 if mock_log_debug: 187 rt.log_debug = mock_log_debug 188 189 rt.refactor_file(test_file) 190 self.assertEqual(old_contents, self.read_file(test_file)) 191 192 if not actually_write: 193 return 194 rt.refactor_file(test_file, True) 195 new_contents = self.read_file(test_file) 196 self.assertNotEqual(old_contents, new_contents) 197 return new_contents 198 199 def init_test_file(self, test_file): 200 tmpdir = tempfile.mkdtemp(prefix="2to3-test_refactor") 201 self.addCleanup(shutil.rmtree, tmpdir) 202 shutil.copy(test_file, tmpdir) 203 test_file = os.path.join(tmpdir, os.path.basename(test_file)) 204 os.chmod(test_file, 0o644) 205 return test_file 206 207 def read_file(self, test_file): 208 with open(test_file, "rb") as fp: 209 return fp.read() 210 211 def refactor_file(self, test_file, fixers=_2TO3_FIXERS): 212 test_file = self.init_test_file(test_file) 213 old_contents = self.read_file(test_file) 214 rt = self.rt(fixers=fixers) 215 rt.refactor_file(test_file, True) 216 new_contents = self.read_file(test_file) 217 return old_contents, new_contents 218 219 def test_refactor_file(self): 220 test_file = os.path.join(FIXER_DIR, "parrot_example.py") 221 self.check_file_refactoring(test_file, _DEFAULT_FIXERS) 222 223 def test_refactor_file_write_unchanged_file(self): 224 test_file = os.path.join(FIXER_DIR, "parrot_example.py") 225 debug_messages = [] 226 def recording_log_debug(msg, *args): 227 debug_messages.append(msg % args) 228 self.check_file_refactoring(test_file, fixers=(), 229 options={"write_unchanged_files": True}, 230 mock_log_debug=recording_log_debug, 231 actually_write=False) 232 # Testing that it logged this message when write=False was passed is 233 # sufficient to see that it did not bail early after "No changes". 234 message_regex = r"Not writing changes to .*%s" % \ 235 re.escape(os.sep + os.path.basename(test_file)) 236 for message in debug_messages: 237 if "Not writing changes" in message: 238 self.assertRegex(message, message_regex) 239 break 240 else: 241 self.fail("%r not matched in %r" % (message_regex, debug_messages)) 242 243 def test_refactor_dir(self): 244 def check(structure, expected): 245 def mock_refactor_file(self, f, *args): 246 got.append(f) 247 save_func = refactor.RefactoringTool.refactor_file 248 refactor.RefactoringTool.refactor_file = mock_refactor_file 249 rt = self.rt() 250 got = [] 251 dir = tempfile.mkdtemp(prefix="2to3-test_refactor") 252 try: 253 os.mkdir(os.path.join(dir, "a_dir")) 254 for fn in structure: 255 open(os.path.join(dir, fn), "wb").close() 256 rt.refactor_dir(dir) 257 finally: 258 refactor.RefactoringTool.refactor_file = save_func 259 shutil.rmtree(dir) 260 self.assertEqual(got, 261 [os.path.join(dir, path) for path in expected]) 262 check([], []) 263 tree = ["nothing", 264 "hi.py", 265 ".dumb", 266 ".after.py", 267 "notpy.npy", 268 "sappy"] 269 expected = ["hi.py"] 270 check(tree, expected) 271 tree = ["hi.py", 272 os.path.join("a_dir", "stuff.py")] 273 check(tree, tree) 274 275 def test_file_encoding(self): 276 fn = os.path.join(TEST_DATA_DIR, "different_encoding.py") 277 self.check_file_refactoring(fn) 278 279 def test_false_file_encoding(self): 280 fn = os.path.join(TEST_DATA_DIR, "false_encoding.py") 281 data = self.check_file_refactoring(fn) 282 283 def test_bom(self): 284 fn = os.path.join(TEST_DATA_DIR, "bom.py") 285 data = self.check_file_refactoring(fn) 286 self.assertTrue(data.startswith(codecs.BOM_UTF8)) 287 288 def test_crlf_newlines(self): 289 old_sep = os.linesep 290 os.linesep = "\r\n" 291 try: 292 fn = os.path.join(TEST_DATA_DIR, "crlf.py") 293 fixes = refactor.get_fixers_from_package("lib2to3.fixes") 294 self.check_file_refactoring(fn, fixes) 295 finally: 296 os.linesep = old_sep 297 298 def test_crlf_unchanged(self): 299 fn = os.path.join(TEST_DATA_DIR, "crlf.py") 300 old, new = self.refactor_file(fn) 301 self.assertIn(b"\r\n", old) 302 self.assertIn(b"\r\n", new) 303 self.assertNotIn(b"\r\r\n", new) 304 305 def test_refactor_docstring(self): 306 rt = self.rt() 307 308 doc = """ 309>>> example() 31042 311""" 312 out = rt.refactor_docstring(doc, "<test>") 313 self.assertEqual(out, doc) 314 315 doc = """ 316>>> def parrot(): 317... return 43 318""" 319 out = rt.refactor_docstring(doc, "<test>") 320 self.assertNotEqual(out, doc) 321 322 def test_explicit(self): 323 from myfixes.fix_explicit import FixExplicit 324 325 rt = self.rt(fixers=["myfixes.fix_explicit"]) 326 self.assertEqual(len(rt.post_order), 0) 327 328 rt = self.rt(explicit=["myfixes.fix_explicit"]) 329 for fix in rt.post_order: 330 if isinstance(fix, FixExplicit): 331 break 332 else: 333 self.fail("explicit fixer not loaded") 334