"""This module can be used for finding similar code"""
import re

import rope.refactor.wildcards
from rope.base import libutils
from rope.base import codeanalyze, exceptions, ast, builtins
from rope.refactor import (patchedast, wildcards)

from rope.refactor.patchedast import MismatchedTokenError


class BadNameInCheckError(exceptions.RefactoringError):
    """Refactoring error type declared for bad names in checks."""


class SimilarFinder(object):
    """`SimilarFinder` can be used to find similar pieces of code

    See the notes in the `rope.refactor.restructure` module for more
    info.

    """

    def __init__(self, pymodule, wildcards=None):
        """Construct a SimilarFinder

        `pymodule` supplies the source code and AST to search in.
        `wildcards` is an optional dict mapping a wildcard kind name to
        a wildcard object; when it is `None` only the default wildcard
        of `rope.refactor.wildcards` is registered.

        """
        self.source = pymodule.source_code
        # `_does_match` reads `self.args`; give it a value up front so
        # the callback never fails before `get_matches` has been called.
        self.args = {}
        try:
            self.raw_finder = RawSimilarFinder(
                pymodule.source_code, pymodule.get_ast(), self._does_match)
        except MismatchedTokenError:
            # Report which file could not be patched, then propagate.
            print("in file %s" % pymodule.resource.path)
            raise
        self.pymodule = pymodule
        if wildcards is None:
            default = rope.refactor.wildcards.DefaultWildcard(
                pymodule.pycore.project)
            self.wildcards = {default.get_name(): default}
        else:
            self.wildcards = wildcards

    def get_matches(self, code, args=None, start=0, end=None):
        """Return an iterator over the matches of `code` in the module

        `args` maps wildcard names to their check arguments.  The
        entry under the empty-string key may hold a ``'skip'`` item of
        the form ``(resource, region)``; when `resource` is this
        module's resource, matches overlapping `region` are left out.
        Only matches lying inside ``[start, end]`` are produced; `end`
        defaults to the end of the source.

        """
        # A fresh dict per call: the former `args={}` default was a
        # single dict shared by every call and kept on `self.args`.
        if args is None:
            args = {}
        self.args = args
        if end is None:
            end = len(self.source)
        skip_region = None
        if 'skip' in args.get('', {}):
            resource, region = args['']['skip']
            if resource == self.pymodule.get_resource():
                skip_region = region
        return self.raw_finder.get_matches(code, start=start, end=end,
                                           skip=skip_region)

    def get_match_regions(self, *args, **kwds):
        """Yield the region of each match found by `get_matches`

        All arguments are forwarded to `get_matches` unchanged.

        """
        for match in self.get_matches(*args, **kwds):
            yield match.get_region()

    def _does_match(self, node, name):
        """Ask the matching wildcard whether `node` may bind to `name`

        The argument registered for `name` in `self.args` is either a
        plain value, which is handled by the ``'default'`` wildcard, or
        a ``(kind, value)`` tuple/list that selects the wildcard.

        """
        arg = self.args.get(name, '')
        kind = 'default'
        if isinstance(arg, (tuple, list)):
            kind = arg[0]
            arg = arg[1]
        suspect = wildcards.Suspect(self.pymodule, node, name)
        return self.wildcards[kind].matches(suspect, arg)


class RawSimilarFinder(object):
    """A class for finding similar expressions and statements"""

    def __init__(self, source, node=None, does_match=None):
        """Prepare a finder for `source`

        `node` is the AST of `source`; it is parsed from `source` when
        omitted.  `does_match` is a ``(node, name)`` callback deciding
        whether an AST node may be bound to the wildcard called `name`;
        `_simple_does_match` is used when it is omitted.

        """
        if node is None:
            node = ast.parse(source)
        if does_match is None:
            self.does_match = self._simple_does_match
        else:
            self.does_match = does_match
        self._init_using_ast(node, source)

    def _simple_does_match(self, node, name):
        """Default wildcard check: accept expression and name nodes"""
        return isinstance(node, (ast.expr, ast.Name))

    def _init_using_ast(self, node, source):
        """Store `source` and its AST, patching the AST if needed"""
        self.source = source
        # Cache of found matches, keyed by the pattern code only.
        self._matched_asts = {}
        # Nodes without a `region` attribute have not been patched yet.
        if not hasattr(node, 'region'):
            patchedast.patch_ast(node, source)
        self.ast = node

    def get_matches(self, code, start=0, end=None, skip=None):
        """Search for `code` in source and yield the matching `Match` objects

        `code` can contain wildcards.  ``${name}`` matches normal
        names and ``${?name}`` can match any expression.  You can use
        `Match.get_ast()` for getting the node that has matched a
        given pattern.

        Only matches whose region lies inside ``[start, end]`` are
        yielded; `end` defaults to the length of the source.  `skip`
        is an optional ``(start, end)`` region; matches overlapping it
        are left out.

        """
        if end is None:
            end = len(self.source)
        for match in self._get_matched_asts(code):
            match_start, match_end = match.get_region()
            if start <= match_start and match_end <= end:
                if skip is not None and (skip[0] < match_end and
                                         skip[1] > match_start):
                    continue
                yield match

    def _get_matched_asts(self, code):
        """Return all matches of `code`, computing them at most once

        NOTE: results are cached per `code` string, so the
        `does_match` callback is only consulted the first time a given
        pattern is searched.

        """
        if code not in self._matched_asts:
            wanted = self._create_pattern(code)
            matches = _ASTMatcher(self.ast, wanted,
                                  self.does_match).find_matches()
            self._matched_asts[code] = matches
        return self._matched_asts[code]

    def _create_pattern(self, expression):
        """Turn pattern code into an AST node or a list of statements

        A pattern made of one expression statement yields that
        expression's node; anything else yields the list of parsed
        statements.

        """
        expression = self._replace_wildcards(expression)
        node = ast.parse(expression)
        # Getting Module.Stmt.nodes
        nodes = node.body
        if len(nodes) == 1 and isinstance(nodes[0], ast.Expr):
            # Getting Discard.expr
            wanted = nodes[0].value
        else:
            wanted = nodes
        return wanted

    def _replace_wildcards(self, expression):
        """Replace each ``${name}`` in `expression` by a rope variable

        The replacement names are plain identifiers, so the result can
        be handed to `ast.parse`.

        """
        ropevar = _RopeVariable()
        template = CodeTemplate(expression)
        mapping = {}
        for name in template.get_names():
            mapping[name] = ropevar.get_var(name)
        return template.substitute(mapping)


class _ASTMatcher(object):
    """Collects the places where a pattern AST occurs in a body AST"""

    def __init__(self, body, pattern, does_match):
        """Searches the given pattern in the body AST.

        body is an AST node and pattern can be either an AST node or
        a list of ASTs nodes.  does_match is a ``(node, name)``
        callback that decides whether a node may be bound to the
        wildcard called `name`.
        """
        self.body = body
        self.pattern = pattern
        self.matches = None
        self.ropevar = _RopeVariable()
        self.matches_callback = does_match

    def find_matches(self):
        """Return the list of matches, searching on the first call only"""
        if self.matches is None:
            self.matches = []
            ast.call_for_nodes(self.body, self._check_node, recursive=True)
        return self.matches

    def _check_node(self, node):
        """Check `node` against a statement-list or expression pattern"""
        if isinstance(self.pattern, list):
            self._check_statements(node)
        else:
            self._check_expression(node)

    def _check_expression(self, node):
        """Record an `ExpressionMatch` if `node` matches the pattern"""
        mapping = {}
        if self._match_nodes(self.pattern, node, mapping):
            self.matches.append(ExpressionMatch(node, mapping))

    def _check_statements(self, node):
        """Search every list/tuple child of `node` for the pattern"""
        for child in ast.get_children(node):
            if isinstance(child, (list, tuple)):
                self.__check_stmt_list(child)

    def __check_stmt_list(self, nodes):
        """Slide a pattern-sized window over `nodes`, recording matches"""
        for index in range(len(nodes)):
            # Skip start positions with too few statements remaining.
            if len(nodes) - index >= len(self.pattern):
                current_stmts = nodes[index:index + len(self.pattern)]
                mapping = {}
                if self._match_stmts(current_stmts, mapping):
                    self.matches.append(StatementMatch(current_stmts, mapping))

    def _match_nodes(self, expected, node, mapping):
        """Return whether `node` structurally matches `expected`

        `expected` comes from the pattern and `node` from the searched
        code.  Wildcard bindings made along the way are stored in
        `mapping`.

        """
        if isinstance(expected, ast.Name):
            if self.ropevar.is_var(expected.id):
                return self._match_wildcard(expected, node, mapping)
        if not isinstance(expected, ast.AST):
            return expected == node
        if expected.__class__ != node.__class__:
            return False

        children1 = self._get_children(expected)
        children2 = self._get_children(node)
        if len(children1) != len(children2):
            return False
        for child1, child2 in zip(children1, children2):
            if isinstance(child1, ast.AST):
                if not self._match_nodes(child1, child2, mapping):
                    return False
            elif isinstance(child1, (list, tuple)):
                if not isinstance(child2, (list, tuple)) or \
                        len(child1) != len(child2):
                    return False
                for c1, c2 in zip(child1, child2):
                    if not self._match_nodes(c1, c2, mapping):
                        return False
            else:
                # Plain values must agree in both type and value.
                if type(child1) is not type(child2) or child1 != child2:
                    return False
        return True

    def _get_children(self, node):
        """Return not `ast.expr_context` children of `node`"""
        children = ast.get_children(node)
        return [child for child in children
                if not isinstance(child, ast.expr_context)]

    def _match_stmts(self, current_stmts, mapping):
        """Return whether `current_stmts` match the pattern pairwise"""
        if len(current_stmts) != len(self.pattern):
            return False
        for stmt, expected in zip(current_stmts, self.pattern):
            if not self._match_nodes(expected, stmt, mapping):
                return False
        return True

    def _match_wildcard(self, node1, node2, mapping):
        """Bind the wildcard `node1` to `node2` or check an old binding

        An unbound wildcard is bound when the callback accepts `node2`.
        A wildcard that is already bound matches only nodes that are
        structurally equal to the node it was bound to.

        """
        name = self.ropevar.get_base(node1.id)
        if name not in mapping:
            if self.matches_callback(node2, name):
                mapping[name] = node2
                return True
            return False
        else:
            return self._match_nodes(mapping[name], node2, {})


class Match(object):
    """Base class of matches; holds the wildcard-name to node mapping"""

    def __init__(self, mapping):
        self.mapping = mapping

    def get_region(self):
        """Returns match region"""
        # The base implementation returns `None`; the subclasses in
        # this module override it.

    def get_ast(self, name):
        """Return the ast node that has matched rope variables"""
        return self.mapping.get(name, None)


class ExpressionMatch(Match):
    """A match made of a single AST node"""

    def __init__(self, ast, mapping):
        super(ExpressionMatch, self).__init__(mapping)
        self.ast = ast

    def get_region(self):
        """Return the `region` of the matched node"""
        return self.ast.region


class StatementMatch(Match):
    """A match made of a run of consecutive statements"""

    def __init__(self, ast_list, mapping):
        super(StatementMatch, self).__init__(mapping)
        self.ast_list = ast_list

    def get_region(self):
        """Return the span from the first statement to the last one"""
        return self.ast_list[0].region[0], self.ast_list[-1].region[1]


class CodeTemplate(object):
    """Code containing ``${name}`` placeholders that can be substituted"""

    def __init__(self, template):
        self.template = template
        self._find_names()

    def _find_names(self):
        """Record the ``(start, end)`` spans of every placeholder name"""
        self.names = {}
        for match in CodeTemplate._get_pattern().finditer(self.template):
            # Only hits of the `name` alternative are placeholders.
            if 'name' in match.groupdict() and \
               match.group('name') is not None:
                start, end = match.span('name')
                # Strip the leading "${" and the trailing "}".
                name = self.template[start + 2:end - 1]
                if name not in self.names:
                    self.names[name] = []
                self.names[name].append((start, end))

    def get_names(self):
        """Return the names of the placeholders found in the template"""
        return self.names.keys()

    def substitute(self, mapping):
        """Return the template with each placeholder replaced

        `mapping` must hold a replacement string for every placeholder
        name.  The template is returned as is when the collector
        reports no change.

        """
        collector = codeanalyze.ChangeCollector(self.template)
        for name, occurrences in self.names.items():
            for region in occurrences:
                collector.add_change(region[0], region[1], mapping[name])
        result = collector.get_changed()
        if result is None:
            return self.template
        return result

    _match_pattern = None

    @classmethod
    def _get_pattern(cls):
        """Return the compiled placeholder regex, building it once

        The comment and string patterns from `codeanalyze` come first
        in the alternation -- presumably so that a ``${...}`` inside a
        comment or a string literal is consumed there and not captured
        by the `name` group.

        """
        if cls._match_pattern is None:
            pattern = codeanalyze.get_comment_pattern() + '|' + \
                codeanalyze.get_string_pattern() + '|' + \
                r'(?P<name>\$\{[^\s\$\}]*\})'
            cls._match_pattern = re.compile(pattern)
        return cls._match_pattern


class _RopeVariable(object):
    """Transform and identify rope inserted wildcards"""

    _normal_prefix = '__rope__variable_normal_'
    _any_prefix = '__rope__variable_any_'

    def get_var(self, name):
        """Return the identifier that stands for wildcard `name`

        Names starting with ``?`` get the "any" prefix (with the ``?``
        dropped); all other names get the "normal" prefix.

        """
        if name.startswith('?'):
            return self._get_any(name)
        else:
            return self._get_normal(name)

    def is_var(self, name):
        """Return whether `name` is a rope inserted identifier"""
        return self._is_normal(name) or self._is_var(name)

    def get_base(self, name):
        """Return the wildcard name an inserted identifier stands for

        Returns `None` when `name` carries neither prefix.

        """
        if self._is_normal(name):
            return name[len(self._normal_prefix):]
        if self._is_var(name):
            return '?' + name[len(self._any_prefix):]
        return None

    def _get_normal(self, name):
        return self._normal_prefix + name

    def _get_any(self, name):
        return self._any_prefix + name[1:]

    def _is_normal(self, name):
        return name.startswith(self._normal_prefix)

    def _is_var(self, name):
        return name.startswith(self._any_prefix)


def make_pattern(code, variables):
    """Return `code` with the given names turned into wildcards

    Every `ast.Name` node in `code` whose id is one of `variables` is
    replaced by ``${name}``.  `code` is returned unchanged when the
    collector reports no change.

    """
    variables = set(variables)
    collector = codeanalyze.ChangeCollector(code)

    def does_match(node, name):
        # Bind a wildcard only to a name node spelled exactly like it.
        return isinstance(node, ast.Name) and node.id == name
    finder = RawSimilarFinder(code, does_match=does_match)
    for variable in variables:
        for match in finder.get_matches('${%s}' % variable):
            start, end = match.get_region()
            collector.add_change(start, end, '${%s}' % variable)
    result = collector.get_changed()
    return result if result is not None else code


def _pydefined_to_str(pydefined):
    """Return a dotted address string for `pydefined`

    Builtin classes and functions become ``__builtins__.<name>``.
    Anything else is addressed by the module name of its outermost
    parent's resource followed by the names along the parent chain.

    """
    address = []
    if isinstance(pydefined,
                  (builtins.BuiltinClass, builtins.BuiltinFunction)):
        return '__builtins__.' + pydefined.get_name()
    else:
        # Walk up to the object that has no parent, collecting names
        # outermost-first.
        while pydefined.parent is not None:
            address.insert(0, pydefined.get_name())
            pydefined = pydefined.parent
        module_name = libutils.modname(pydefined.resource)
    return '.'.join(module_name.split('.') + address)